From 60949373db80ea613b265046708a5c0a3116147d Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 16 May 2026 21:17:42 +0800 Subject: [PATCH 01/98] wip --- bolt/core/QueryConfig.h | 18 ++ bolt/exec/Aggregate.h | 12 + bolt/exec/AggregateInfo.cpp | 2 + bolt/exec/AggregateInfo.h | 8 + bolt/exec/GroupingSet.cpp | 262 +++++++++++++++++++ bolt/exec/GroupingSet.h | 21 ++ bolt/jit/CMakeLists.txt | 1 + bolt/jit/aggregation/HashAggrJit.cpp | 373 +++++++++++++++++++++++++++ bolt/jit/aggregation/HashAggrJit.h | 81 ++++++ 9 files changed, 778 insertions(+) create mode 100644 bolt/jit/aggregation/HashAggrJit.cpp create mode 100644 bolt/jit/aggregation/HashAggrJit.h diff --git a/bolt/core/QueryConfig.h b/bolt/core/QueryConfig.h index db5566da4..1d96c5ebf 100644 --- a/bolt/core/QueryConfig.h +++ b/bolt/core/QueryConfig.h @@ -652,6 +652,12 @@ class QueryConfig { */ static constexpr const char* kJitLevel = "jit.level"; + static constexpr const char* kHashAggrJitEnabled = "hashaggr.jit.enabled"; + static constexpr const char* kHashAggrJitMinFuseWidth = + "hashaggr.jit.min_fuse_width"; + static constexpr const char* kHashAggrJitMaxFuseWidth = + "hashaggr.jit.max_fuse_width"; + // expired, to deleted later static constexpr const char* kBoltJitEnabled = "bolt.jit.enabled"; // For morsel-driven Bolt @@ -1606,6 +1612,18 @@ class QueryConfig { return flag & 1; } + bool enableHashAggrJit() const { + return get(kHashAggrJitEnabled, false); + } + + int32_t hashAggrJitMinFuseWidth() const { + return get(kHashAggrJitMinFuseWidth, 4); + } + + int32_t hashAggrJitMaxFuseWidth() const { + return get(kHashAggrJitMaxFuseWidth, 16); + } + int exceptionTraceLevel() const { return get(kExceptionTraceLevel, 1); } diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index 433abd8e5..a8422b787 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -66,6 +66,18 @@ class Aggregate { return resultType_; } + int32_t accumulatorOffset() const { + return offset_; + } + + int32_t 
accumulatorNullByte() const { + return nullByte_; + } + + uint8_t accumulatorNullMask() const { + return nullMask_; + } + // Returns the fixed number of bytes the accumulator takes on a group // row. Variable width accumulators will reference the variable // width part of the state from the fixed part. diff --git a/bolt/exec/AggregateInfo.cpp b/bolt/exec/AggregateInfo.cpp index 87ee2b655..0561087b6 100644 --- a/bolt/exec/AggregateInfo.cpp +++ b/bolt/exec/AggregateInfo.cpp @@ -71,6 +71,8 @@ std::vector toAggregateInfo( for (auto i = 0; i < numAggregates; i++) { const auto& aggregate = aggregationNode.aggregates()[i]; AggregateInfo info; + info.name = aggregate.call->name(); + info.rawInputTypes = aggregate.rawInputTypes; // Populate input. auto& channels = info.inputs; auto& constants = info.constantInputs; diff --git a/bolt/exec/AggregateInfo.h b/bolt/exec/AggregateInfo.h index d1c4d2ff9..e124a6984 100644 --- a/bolt/exec/AggregateInfo.h +++ b/bolt/exec/AggregateInfo.h @@ -41,6 +41,14 @@ struct AggregateInfo { /// Instance of the Aggregate class. std::unique_ptr function; + /// Function name used to create the aggregate. Kept for optional fast paths + /// such as aggregation JIT planning. + std::string name; + + /// Raw input types from the plan. Kept to identify JIT-able numeric variants + /// without depending on concrete Aggregate subclasses. + std::vector rawInputTypes; + /// Indices of the input columns in the input RowVector. 
std::vector inputs; diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 2c6e5bad4..b3cb42726 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -36,6 +36,9 @@ #include "bolt/exec/ContainerRow2RowSerde.h" #include "bolt/exec/OperatorUtils.h" #include "bolt/exec/RowToColumnVector.h" +#ifdef ENABLE_BOLT_JIT +#include "bolt/jit/aggregation/HashAggrJit.h" +#endif #include "bolt/type/Type.h" #include "bolt/vector/ComplexVector.h" @@ -59,6 +62,140 @@ bool areAllLazyNotLoaded(const std::vector& vectors) { }); } +#ifdef ENABLE_BOLT_JIT +std::string normalizedAggName(const std::string& name) { + static constexpr std::string_view kSparkPrefix{"spark_"}; + if (name.size() > kSparkPrefix.size() && + name.compare(0, kSparkPrefix.size(), kSparkPrefix) == 0) { + return name.substr(kSparkPrefix.size()); + } + return name; +} + +std::optional hashAggrJitKind(const std::string& name) { + const auto normalized = normalizedAggName(name); + if (normalized == "count") { + return jit::HashAggrJitKind::Count; + } + if (normalized == "sum") { + return jit::HashAggrJitKind::Sum; + } + if (normalized == "min") { + return jit::HashAggrJitKind::Min; + } + if (normalized == "max") { + return jit::HashAggrJitKind::Max; + } + if (normalized == "avg") { + return jit::HashAggrJitKind::Avg; + } + return std::nullopt; +} + +std::optional hashAggrJitValueKind(TypeKind kind) { + switch (kind) { + case TypeKind::TINYINT: + return jit::HashAggrJitValueKind::Int8; + case TypeKind::SMALLINT: + return jit::HashAggrJitValueKind::Int16; + case TypeKind::INTEGER: + return jit::HashAggrJitValueKind::Int32; + case TypeKind::BIGINT: + return jit::HashAggrJitValueKind::Int64; + case TypeKind::REAL: + return jit::HashAggrJitValueKind::Float; + case TypeKind::DOUBLE: + return jit::HashAggrJitValueKind::Double; + default: + return std::nullopt; + } +} + +bool isIntegralJitKind(jit::HashAggrJitValueKind kind) { + return kind == jit::HashAggrJitValueKind::Int8 || + kind == 
jit::HashAggrJitValueKind::Int16 || + kind == jit::HashAggrJitValueKind::Int32 || + kind == jit::HashAggrJitValueKind::Int64; +} + +std::optional makeHashAggrJitSlot( + int32_t aggregateIndex, + const AggregateInfo& aggregate) { + if (aggregate.distinct || aggregate.mask.has_value() || + !aggregate.sortingKeys.empty()) { + return std::nullopt; + } + + auto kind = hashAggrJitKind(aggregate.name); + if (!kind.has_value()) { + return std::nullopt; + } + + const bool countStar = *kind == jit::HashAggrJitKind::Count && + aggregate.inputs.empty(); + if (!countStar && aggregate.rawInputTypes.size() != 1) { + return std::nullopt; + } + + jit::HashAggrJitValueKind inputKind = jit::HashAggrJitValueKind::Int64; + if (!countStar) { + const auto& inputType = aggregate.rawInputTypes[0]; + if (inputType->isDecimal() || + !jit::isHashAggrJitSupportedType(inputType->kind())) { + return std::nullopt; + } + auto maybeInputKind = hashAggrJitValueKind(inputType->kind()); + if (!maybeInputKind.has_value()) { + return std::nullopt; + } + inputKind = *maybeInputKind; + } + + jit::HashAggrJitValueKind accumulatorKind = inputKind; + switch (*kind) { + case jit::HashAggrJitKind::Count: + accumulatorKind = jit::HashAggrJitValueKind::Int64; + break; + case jit::HashAggrJitKind::Sum: + accumulatorKind = (inputKind == jit::HashAggrJitValueKind::Float || + inputKind == jit::HashAggrJitValueKind::Double) + ? jit::HashAggrJitValueKind::Double + : jit::HashAggrJitValueKind::Int64; + break; + case jit::HashAggrJitKind::Avg: + accumulatorKind = jit::HashAggrJitValueKind::Double; + break; + case jit::HashAggrJitKind::Min: + case jit::HashAggrJitKind::Max: + // Keep floating-point min/max on the existing implementation to preserve + // Velox/Spark NaN ordering edge cases exactly. 
+ if (!isIntegralJitKind(inputKind)) { + return std::nullopt; + } + accumulatorKind = inputKind; + break; + } + + return jit::HashAggrJitSlot{ + aggregateIndex, + *kind, + inputKind, + accumulatorKind, + aggregate.function->accumulatorOffset(), + aggregate.function->accumulatorNullByte(), + aggregate.function->accumulatorNullMask(), + countStar}; +} + +std::string hashAggrJitSignature(const jit::HashAggrJitSlot& slot) { + return fmt::format( + "{}_{}_{}", + static_cast(slot.kind), + jit::hashAggrJitValueKindName(slot.inputKind), + jit::hashAggrJitValueKindName(slot.accumulatorKind)); +} +#endif + } // namespace GroupingSet::GroupingSet( @@ -289,7 +426,14 @@ void GroupingSet::addInputForActiveRows( NanosecondTimer funcTimer(&stats_.aggFunctionTimeNs); auto* groups = lookup_->hits.data(); auto& newGroups = lookup_->newGroups; + std::vector jitExecuted; +#ifdef ENABLE_BOLT_JIT + runHashAggrJitChunks(groups, newGroups, input, mayPushdown, jitExecuted); +#endif for (auto i = 0; i < aggregates_.size(); ++i) { + if (!jitExecuted.empty() && jitExecuted[i]) { + continue; + } if (!aggregates_[i].sortingKeys.empty()) { continue; } @@ -419,6 +563,9 @@ void GroupingSet::createHashTable() { RowContainer& rows = *table_->rows(); initializeAggregates(aggregates_, rows, false); +#ifdef ENABLE_BOLT_JIT + maybeCreateHashAggrJitPlan(); +#endif auto numColumns = rows.keyTypes().size() + aggregates_.size(); @@ -771,6 +918,121 @@ const SelectivityVector& GroupingSet::getSelectivityVector( return *rows; } +#ifdef ENABLE_BOLT_JIT +void GroupingSet::maybeCreateHashAggrJitPlan() { + hashAggrJitChunks_.clear(); + if (!queryConfig_.enableHashAggrJit() || !isRawInput_ || isGlobal_ || + ignoreNullKeys_) { + return; + } + + const auto minFuseWidth = std::max(1, queryConfig_.hashAggrJitMinFuseWidth()); + const auto maxFuseWidth = std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); + std::unordered_map> groups; + for (auto i = 0; i < aggregates_.size(); ++i) { + auto slot = makeHashAggrJitSlot(i, 
aggregates_[i]); + if (!slot.has_value()) { + continue; + } + groups[hashAggrJitSignature(*slot)].push_back(*slot); + } + + for (auto& [_, slots] : groups) { + if (slots.size() < minFuseWidth) { + continue; + } + for (auto begin = 0; begin < slots.size(); begin += maxFuseWidth) { + const auto end = std::min(begin + maxFuseWidth, slots.size()); + std::vector chunkSlots( + slots.begin() + begin, slots.begin() + end); + if (chunkSlots.size() < minFuseWidth) { + continue; + } + jit::HashAggrJitChunk chunk(std::move(chunkSlots)); + if (chunk.codegen()) { + hashAggrJitChunks_.push_back(std::move(chunk)); + } + } + } +} + +void GroupingSet::runHashAggrJitChunks( + char** groups, + folly::Range newGroups, + const RowVectorPtr& input, + bool mayPushdown, + std::vector& jitExecuted) { + if (hashAggrJitChunks_.empty() || hasSpilled() || bypassProbeHT_ || + !activeRows_.isAllSelected()) { + return; + } + + jitExecuted.assign(aggregates_.size(), 0); + for (auto& chunk : hashAggrJitChunks_) { + if (!chunk.enabled()) { + continue; + } + + const auto numSlots = chunk.slots().size(); + hashAggrJitDecoded_.resize(numSlots); + hashAggrJitInputVectors_.assign(numSlots, nullptr); + hashAggrJitDecodedPtrs_.assign(numSlots, nullptr); + + bool canRunChunk = true; + for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { + const auto& slot = chunk.slots()[slotIndex]; + const auto& aggregate = aggregates_[slot.aggregateIndex]; + if (aggregate.mask.has_value() || aggregate.distinct || + !aggregate.sortingKeys.empty()) { + canRunChunk = false; + break; + } + const auto& rows = getSelectivityVector(slot.aggregateIndex); + if (&rows != &activeRows_ || !rows.hasSelections()) { + canRunChunk = false; + break; + } + if (slot.countStar) { + continue; + } + if (aggregate.inputs.size() != 1) { + canRunChunk = false; + break; + } + + VectorPtr arg; + if (aggregate.inputs[0] == kConstantChannel) { + arg = BaseVector::wrapInConstant(input->size(), 0, aggregate.constantInputs[0]); + } else { + arg 
= input->childAt(aggregate.inputs[0]); + } + if (mayPushdown && mayPushdown_[slot.aggregateIndex] && isLazyNotLoaded(*arg)) { + canRunChunk = false; + break; + } + hashAggrJitInputVectors_[slotIndex] = arg; + hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); + hashAggrJitDecodedPtrs_[slotIndex] = reinterpret_cast(&hashAggrJitDecoded_[slotIndex]); + } + + if (!canRunChunk) { + continue; + } + + if (!newGroups.empty()) { + for (const auto& slot : chunk.slots()) { + aggregates_[slot.aggregateIndex].function->initializeNewGroups(groups, newGroups); + } + } + + chunk.addDense(groups, activeRows_.end(), hashAggrJitDecodedPtrs_.data()); + for (const auto& slot : chunk.slots()) { + jitExecuted[slot.aggregateIndex] = 1; + } + } +} +#endif + bool GroupingSet::getOutput( int32_t maxOutputRows, int32_t maxOutputBytes, diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 3927165e6..02700cbfe 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -39,6 +39,10 @@ #include "bolt/exec/Spiller.h" #include "bolt/exec/TreeOfLosers.h" #include "bolt/exec/VectorHasher.h" +#ifdef ENABLE_BOLT_JIT +#include "bolt/jit/aggregation/HashAggrJit.h" +#endif +#include "bolt/vector/DecodedVector.h" namespace bytedance::bolt::exec { class GroupingSet { @@ -282,6 +286,16 @@ class GroupingSet { // index for this aggregation), otherwise it returns reference to activeRows_. const SelectivityVector& getSelectivityVector(size_t aggregateIndex) const; +#ifdef ENABLE_BOLT_JIT + void maybeCreateHashAggrJitPlan(); + void runHashAggrJitChunks( + char** groups, + folly::Range newGroups, + const RowVectorPtr& input, + bool mayPushdown, + std::vector& jitExecuted); +#endif + // Checks if input will fit in the existing memory and increases reservation // if not. If reservation cannot be increased, spills enough to make 'input' // fit. 
@@ -442,6 +456,13 @@ class GroupingSet { std::unique_ptr sortedAggregations_; std::vector> distinctAggregations_; +#ifdef ENABLE_BOLT_JIT + std::vector hashAggrJitChunks_; + std::vector hashAggrJitDecoded_; + std::vector hashAggrJitInputVectors_; + std::vector hashAggrJitDecodedPtrs_; +#endif + // True if any aggregate accumulator allocates memory outside RowContainer's // HashStringAllocator (e.g. directly from MemoryPool). In that case, // RowContainer::estimateRowSize() can under-estimate the actual per-group diff --git a/bolt/jit/CMakeLists.txt b/bolt/jit/CMakeLists.txt index cade0033d..a0af680ce 100644 --- a/bolt/jit/CMakeLists.txt +++ b/bolt/jit/CMakeLists.txt @@ -16,6 +16,7 @@ bolt_add_library( bolt_thrustjit CompiledModule.cpp ThrustJITv2.cpp + aggregation/HashAggrJit.cpp RowContainer/RowContainerCodeGenerator.cpp RowContainer/RowEqVectorsCodeGenerator.cpp ) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp new file mode 100644 index 000000000..c49b6e697 --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -0,0 +1,373 @@ +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +#include +#include +#include + +#include + +#include "bolt/jit/ThrustJITv2.h" + +namespace bytedance::bolt::jit { +namespace { + +llvm::FunctionCallee declareFunction( + llvm::Module& module, + llvm::StringRef name, + llvm::Type* returnType, + llvm::ArrayRef argTypes) { + return module.getOrInsertFunction( + name, llvm::FunctionType::get(returnType, argTypes, false)); +} + +void ensureBuiltinDeclarations(llvm::Module& module) { + auto& context = module.getContext(); + auto* i8Ty = llvm::Type::getInt8Ty(context); + auto* i16Ty = llvm::Type::getInt16Ty(context); + auto* i32Ty = llvm::Type::getInt32Ty(context); + auto* i64Ty = llvm::Type::getInt64Ty(context); + auto* floatTy = llvm::Type::getFloatTy(context); + auto* doubleTy = llvm::Type::getDoubleTy(context); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + + 
declareFunction(module, "jit_GetDecodedValueI8", i8Ty, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_GetDecodedValueI16", i16Ty, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_GetDecodedValueI32", i32Ty, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_GetDecodedValueI64", i64Ty, {i8PtrTy, i32Ty}); + declareFunction( + module, "jit_GetDecodedValueFloat", floatTy, {i8PtrTy, i32Ty}); + declareFunction( + module, "jit_GetDecodedValueDouble", doubleTy, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_GetDecodedIsNull", i8Ty, {i8PtrTy, i32Ty}); +} + +llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Int8: + return builder.getInt8Ty(); + case HashAggrJitValueKind::Int16: + return builder.getInt16Ty(); + case HashAggrJitValueKind::Int32: + return builder.getInt32Ty(); + case HashAggrJitValueKind::Int64: + return builder.getInt64Ty(); + case HashAggrJitValueKind::Float: + return builder.getFloatTy(); + case HashAggrJitValueKind::Double: + return builder.getDoubleTy(); + } + return builder.getInt64Ty(); +} + +std::string decodedValueFunction(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Int8: + return "jit_GetDecodedValueI8"; + case HashAggrJitValueKind::Int16: + return "jit_GetDecodedValueI16"; + case HashAggrJitValueKind::Int32: + return "jit_GetDecodedValueI32"; + case HashAggrJitValueKind::Int64: + return "jit_GetDecodedValueI64"; + case HashAggrJitValueKind::Float: + return "jit_GetDecodedValueFloat"; + case HashAggrJitValueKind::Double: + return "jit_GetDecodedValueDouble"; + } + return "jit_GetDecodedValueI64"; +} + +bool isFloatKind(HashAggrJitValueKind kind) { + return kind == HashAggrJitValueKind::Float || + kind == HashAggrJitValueKind::Double; +} + +llvm::LoadInst* loadValue( + llvm::IRBuilder<>& builder, + llvm::Value* row, + llvm::Type* type, + int32_t offset) { + auto* addr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), row, 
static_cast(offset)); + auto* castAddr = builder.CreatePointerCast(addr, type->getPointerTo()); + auto* load = builder.CreateLoad(type, castAddr); + load->setAlignment(llvm::Align(1)); + return load; +} + +void storeValue( + llvm::IRBuilder<>& builder, + llvm::Value* row, + llvm::Type* type, + int32_t offset, + llvm::Value* value) { + auto* addr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), row, static_cast(offset)); + auto* castAddr = builder.CreatePointerCast(addr, type->getPointerTo()); + auto* store = builder.CreateStore(value, castAddr); + store->setAlignment(llvm::Align(1)); +} + +llvm::Value* castValue( + llvm::IRBuilder<>& builder, + llvm::Value* value, + HashAggrJitValueKind from, + HashAggrJitValueKind to) { + if (from == to) { + return value; + } + auto* toType = llvmType(builder, to); + if (isFloatKind(from) && isFloatKind(to)) { + return builder.CreateFPCast(value, toType); + } + if (!isFloatKind(from) && isFloatKind(to)) { + return builder.CreateSIToFP(value, toType); + } + if (isFloatKind(from) && !isFloatKind(to)) { + return builder.CreateFPToSI(value, toType); + } + return builder.CreateSExtOrTrunc(value, toType); +} + +llvm::Value* isAccumulatorNull( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* byte = loadValue(builder, group, builder.getInt8Ty(), slot.nullByte); + auto* mask = llvm::ConstantInt::get(builder.getInt8Ty(), slot.nullMask); + return builder.CreateICmpNE( + builder.CreateAnd(byte, mask), builder.getInt8(0)); +} + +void clearAccumulatorNull( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* byte = loadValue(builder, group, builder.getInt8Ty(), slot.nullByte); + auto* mask = llvm::ConstantInt::get( + builder.getInt8Ty(), static_cast(~slot.nullMask)); + storeValue( + builder, + group, + builder.getInt8Ty(), + slot.nullByte, + builder.CreateAnd(byte, mask)); +} + +llvm::Value* loadDecodedValue( + llvm::IRBuilder<>& builder, + 
llvm::Module& module, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot) { + auto* callee = module.getFunction(decodedValueFunction(slot.inputKind)); + return builder.CreateCall(callee, {decoded, row}); +} + +void genCountUpdate( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* state = loadValue(builder, group, builder.getInt64Ty(), slot.offset); + storeValue( + builder, + group, + builder.getInt64Ty(), + slot.offset, + builder.CreateAdd(state, builder.getInt64(1))); +} + +void genNonNullUpdate( + llvm::IRBuilder<>& builder, + llvm::Value* group, + llvm::Value* rawValue, + const HashAggrJitSlot& slot) { + auto* accType = llvmType(builder, slot.accumulatorKind); + auto* value = castValue(builder, rawValue, slot.inputKind, slot.accumulatorKind); + switch (slot.kind) { + case HashAggrJitKind::Sum: { + clearAccumulatorNull(builder, group, slot); + auto* oldValue = loadValue(builder, group, accType, slot.offset); + auto* newValue = isFloatKind(slot.accumulatorKind) + ? builder.CreateFAdd(oldValue, value) + : builder.CreateAdd(oldValue, value); + storeValue(builder, group, accType, slot.offset, newValue); + break; + } + case HashAggrJitKind::Avg: { + clearAccumulatorNull(builder, group, slot); + auto* oldSum = loadValue(builder, group, accType, slot.offset); + auto* newSum = builder.CreateFAdd(oldSum, value); + storeValue(builder, group, accType, slot.offset, newSum); + auto* oldCount = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); + storeValue( + builder, + group, + builder.getInt64Ty(), + slot.offset + 8, + builder.CreateAdd(oldCount, builder.getInt64(1))); + break; + } + case HashAggrJitKind::Min: + case HashAggrJitKind::Max: { + auto* oldValue = loadValue(builder, group, accType, slot.offset); + auto* nullState = isAccumulatorNull(builder, group, slot); + auto* better = slot.kind == HashAggrJitKind::Min + ? 
builder.CreateICmpSLT(value, oldValue) + : builder.CreateICmpSGT(value, oldValue); + auto* shouldStore = builder.CreateOr(nullState, better); + auto* selected = builder.CreateSelect(shouldStore, value, oldValue); + storeValue(builder, group, accType, slot.offset, selected); + clearAccumulatorNull(builder, group, slot); + break; + } + case HashAggrJitKind::Count: + genCountUpdate(builder, group, slot); + break; + } +} + +bool genAddDenseIR(llvm::Module& module, const std::string& fn, const std::vector& slots) { + ensureBuiltinDeclarations(module); + auto& context = module.getContext(); + llvm::IRBuilder<> builder(context); + auto* voidTy = builder.getVoidTy(); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); + auto* i32Ty = builder.getInt32Ty(); + auto* funcTy = llvm::FunctionType::get(voidTy, {i8PtrPtrTy, i32Ty, i8PtrPtrTy}, false); + auto* func = llvm::Function::Create(funcTy, llvm::Function::ExternalLinkage, fn, module); + auto argIt = func->arg_begin(); + llvm::Value* groups = &*argIt++; + groups->setName("groups"); + llvm::Value* numRows = &*argIt++; + numRows->setName("num_rows"); + llvm::Value* decodedInputs = &*argIt++; + decodedInputs->setName("decoded_inputs"); + + auto* entry = llvm::BasicBlock::Create(context, "entry", func); + auto* loop = llvm::BasicBlock::Create(context, "loop", func); + auto* end = llvm::BasicBlock::Create(context, "end", func); + builder.SetInsertPoint(entry); + builder.CreateCondBr(builder.CreateICmpSLE(numRows, builder.getInt32(0)), end, loop); + + builder.SetInsertPoint(loop); + auto* row = builder.CreatePHI(i32Ty, 2, "row"); + row->addIncoming(builder.getInt32(0), entry); + auto* groupAddr = builder.CreateInBoundsGEP(i8PtrTy, groups, row); + auto* group = builder.CreateLoad(i8PtrTy, groupAddr); + + for (auto i = 0; i < slots.size(); ++i) { + const auto& slot = slots[i]; + if (slot.kind == HashAggrJitKind::Count && slot.countStar) { + genCountUpdate(builder, group, slot); + 
continue; + } + + auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); + auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); + auto* isNull = builder.CreateICmpNE( + builder.CreateCall(module.getFunction("jit_GetDecodedIsNull"), {decoded, row}), + builder.getInt8(0)); + auto* updateBlock = llvm::BasicBlock::Create(context, "slot_update", func, end); + auto* nextBlock = llvm::BasicBlock::Create(context, "slot_next", func, end); + builder.CreateCondBr(isNull, nextBlock, updateBlock); + + builder.SetInsertPoint(updateBlock); + if (slot.kind == HashAggrJitKind::Count) { + genCountUpdate(builder, group, slot); + } else { + auto* value = loadDecodedValue(builder, module, decoded, row, slot); + genNonNullUpdate(builder, group, value, slot); + } + builder.CreateBr(nextBlock); + builder.SetInsertPoint(nextBlock); + } + + auto* next = builder.CreateAdd(row, builder.getInt32(1)); + row->addIncoming(next, builder.GetInsertBlock()); + builder.CreateCondBr(builder.CreateICmpSLT(next, numRows), loop, end); + + builder.SetInsertPoint(end); + builder.CreateRetVoid(); + + return llvm::verifyFunction(*func, &llvm::errs()); +} + +} // namespace + +HashAggrJitChunk::HashAggrJitChunk(std::vector slots) + : slots_(std::move(slots)) {} + +std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Int8: + return "i8"; + case HashAggrJitValueKind::Int16: + return "i16"; + case HashAggrJitValueKind::Int32: + return "i32"; + case HashAggrJitValueKind::Int64: + return "i64"; + case HashAggrJitValueKind::Float: + return "f32"; + case HashAggrJitValueKind::Double: + return "f64"; + } + return "unknown"; +} + +bool isHashAggrJitSupportedType(TypeKind kind) { + switch (kind) { + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + case TypeKind::REAL: + case TypeKind::DOUBLE: + return true; + default: + return false; + } +} + +std::string 
HashAggrJitChunk::functionName() const { + std::ostringstream out; + out << "jit_hashaggr_add_dense_v1_n" << slots_.size(); + for (const auto& slot : slots_) { + out << "_" << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) + << hashAggrJitValueKindName(slot.accumulatorKind) << "o" << slot.offset + << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) + << (slot.countStar ? "s" : "x"); + } + return out.str(); +} + +bool HashAggrJitChunk::codegen() { + if (addDense_) { + return true; + } + auto* jit = ThrustJITv2::getInstance(); + if (jit == nullptr) { + return false; + } + const auto fn = functionName(); + module_ = jit->CompileModule( + [&](llvm::Module& module) { return genAddDenseIR(module, fn, slots_); }, fn); + if (!module_) { + disabled_ = true; + return false; + } + addDense_ = reinterpret_cast(module_->getFuncPtr(fn)); + if (addDense_ == nullptr) { + disabled_ = true; + return false; + } + return true; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h new file mode 100644 index 000000000..aa9b52f7b --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -0,0 +1,81 @@ +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include +#include +#include +#include + +#include "bolt/jit/CompiledModule.h" +#include "bolt/type/Type.h" + +namespace bytedance::bolt::jit { + +enum class HashAggrJitKind : uint8_t { + Count, + Sum, + Min, + Max, + Avg, +}; + +enum class HashAggrJitValueKind : uint8_t { + Int8, + Int16, + Int32, + Int64, + Float, + Double, +}; + +struct HashAggrJitSlot { + int32_t aggregateIndex; + HashAggrJitKind kind; + HashAggrJitValueKind inputKind; + HashAggrJitValueKind accumulatorKind; + int32_t offset; + int32_t nullByte; + uint8_t nullMask; + bool countStar{false}; +}; + +using HashAggrJitAddDenseFunc = void (*)(char** groups, int32_t numRows, char** decodedInputs); + +class HashAggrJitChunk { + public: + explicit 
HashAggrJitChunk(std::vector slots); + + bool codegen(); + + bool enabled() const { + return addDense_ != nullptr && !disabled_; + } + + void disable() { + disabled_ = true; + } + + void addDense(char** groups, int32_t numRows, char** decodedInputs) const { + addDense_(groups, numRows, decodedInputs); + } + + const std::vector& slots() const { + return slots_; + } + + std::string functionName() const; + + private: + std::vector slots_; + CompiledModuleSP module_; + HashAggrJitAddDenseFunc addDense_{nullptr}; + bool disabled_{false}; +}; + +bool isHashAggrJitSupportedType(TypeKind kind); +std::string hashAggrJitValueKindName(HashAggrJitValueKind kind); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT From 913688c491362ae6aa1035d7d5845e598448aeaf Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 16 May 2026 22:55:03 +0800 Subject: [PATCH 02/98] wip2 --- bolt/core/QueryConfig.h | 6 + bolt/exec/GroupingSet.cpp | 24 ++- bolt/exec/GroupingSet.h | 1 + bolt/exec/benchmarks/CMakeLists.txt | 9 ++ bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 146 ++++++++++++++++++ bolt/jit/aggregation/HashAggrJit.cpp | 143 +++++++++++++++-- bolt/jit/aggregation/HashAggrJit.h | 19 ++- 7 files changed, 329 insertions(+), 19 deletions(-) create mode 100644 bolt/exec/benchmarks/HashAggrJitBenchmark.cpp diff --git a/bolt/core/QueryConfig.h b/bolt/core/QueryConfig.h index 1d96c5ebf..de88fd923 100644 --- a/bolt/core/QueryConfig.h +++ b/bolt/core/QueryConfig.h @@ -657,6 +657,8 @@ class QueryConfig { "hashaggr.jit.min_fuse_width"; static constexpr const char* kHashAggrJitMaxFuseWidth = "hashaggr.jit.max_fuse_width"; + static constexpr const char* kHashAggrJitCompileMinCount = + "hashaggr.jit.compile_min_count"; // expired, to deleted later static constexpr const char* kBoltJitEnabled = "bolt.jit.enabled"; @@ -1624,6 +1626,10 @@ class QueryConfig { return get(kHashAggrJitMaxFuseWidth, 16); } + int32_t hashAggrJitCompileMinCount() const { + return 
get(kHashAggrJitCompileMinCount, 3); + } + int exceptionTraceLevel() const { return get(kExceptionTraceLevel, 1); } diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index b3cb42726..b5ded79f1 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -928,6 +928,9 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto minFuseWidth = std::max(1, queryConfig_.hashAggrJitMinFuseWidth()); const auto maxFuseWidth = std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); + const auto compileMinCount = + std::max(1, queryConfig_.hashAggrJitCompileMinCount()); + const auto minChunkWidth = std::max(minFuseWidth, compileMinCount); std::unordered_map> groups; for (auto i = 0; i < aggregates_.size(); ++i) { auto slot = makeHashAggrJitSlot(i, aggregates_[i]); @@ -938,14 +941,14 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { } for (auto& [_, slots] : groups) { - if (slots.size() < minFuseWidth) { + if (slots.size() < minChunkWidth) { continue; } for (auto begin = 0; begin < slots.size(); begin += maxFuseWidth) { const auto end = std::min(begin + maxFuseWidth, slots.size()); std::vector chunkSlots( slots.begin() + begin, slots.begin() + end); - if (chunkSlots.size() < minFuseWidth) { + if (chunkSlots.size() < minChunkWidth) { continue; } jit::HashAggrJitChunk chunk(std::move(chunkSlots)); @@ -963,7 +966,7 @@ void GroupingSet::runHashAggrJitChunks( bool mayPushdown, std::vector& jitExecuted) { if (hashAggrJitChunks_.empty() || hasSpilled() || bypassProbeHT_ || - !activeRows_.isAllSelected()) { + supportRowBasedOutput_ || !activeRows_.isAllSelected()) { return; } @@ -979,6 +982,7 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitDecodedPtrs_.assign(numSlots, nullptr); bool canRunChunk = true; + bool inputsMayHaveNulls = false; for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { const auto& slot = chunk.slots()[slotIndex]; const auto& aggregate = aggregates_[slot.aggregateIndex]; @@ -1012,6 +1016,8 @@ void 
GroupingSet::runHashAggrJitChunks( } hashAggrJitInputVectors_[slotIndex] = arg; hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); + inputsMayHaveNulls = + inputsMayHaveNulls || hashAggrJitDecoded_[slotIndex].mayHaveNulls(); hashAggrJitDecodedPtrs_[slotIndex] = reinterpret_cast(&hashAggrJitDecoded_[slotIndex]); } @@ -1020,12 +1026,18 @@ void GroupingSet::runHashAggrJitChunks( } if (!newGroups.empty()) { - for (const auto& slot : chunk.slots()) { - aggregates_[slot.aggregateIndex].function->initializeNewGroups(groups, newGroups); + hashAggrJitNewGroups_.resize(newGroups.size()); + for (auto i = 0; i < newGroups.size(); ++i) { + hashAggrJitNewGroups_[i] = groups[newGroups[i]]; } + chunk.init(hashAggrJitNewGroups_.data(), newGroups.size()); } - chunk.addDense(groups, activeRows_.end(), hashAggrJitDecodedPtrs_.data()); + chunk.addDense( + groups, + activeRows_.end(), + hashAggrJitDecodedPtrs_.data(), + inputsMayHaveNulls); for (const auto& slot : chunk.slots()) { jitExecuted[slot.aggregateIndex] = 1; } diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 02700cbfe..9689c63de 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -461,6 +461,7 @@ class GroupingSet { std::vector hashAggrJitDecoded_; std::vector hashAggrJitInputVectors_; std::vector hashAggrJitDecodedPtrs_; + std::vector hashAggrJitNewGroups_; #endif // True if any aggregate accumulator allocates memory outside RowContainer's diff --git a/bolt/exec/benchmarks/CMakeLists.txt b/bolt/exec/benchmarks/CMakeLists.txt index dfeef8cd5..de2d6a864 100644 --- a/bolt/exec/benchmarks/CMakeLists.txt +++ b/bolt/exec/benchmarks/CMakeLists.txt @@ -70,6 +70,15 @@ target_link_libraries( GTest::gtest_main ) +add_executable(bolt_hashaggr_jit_benchmark HashAggrJitBenchmark.cpp) + +target_link_libraries( + bolt_hashaggr_jit_benchmark PRIVATE + bolt_testutils + ${FOLLY_BENCHMARK} + GTest::gtest_main +) + add_executable(bolt_sort_random_data_benchmark SortRandomDataBenchmark.cpp) 
add_executable(bolt_sort_window_benchmark SortWindowBenchmark.cpp) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp new file mode 100644 index 000000000..598f74652 --- /dev/null +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include "bolt/core/QueryConfig.h" +#include "bolt/exec/tests/utils/AssertQueryBuilder.h" +#include "bolt/exec/tests/utils/PlanBuilder.h" +#include "bolt/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "bolt/functions/sparksql/aggregates/Register.h" +#include "bolt/parse/TypeResolver.h" +#include "bolt/vector/tests/utils/VectorTestBase.h" + +using namespace bytedance::bolt; +using namespace bytedance::bolt::exec; +using namespace bytedance::bolt::test; + +DEFINE_int32(hashaggr_jit_benchmark_batches, 20, "Number of input batches."); +DEFINE_int32(hashaggr_jit_benchmark_batch_size, 10000, "Rows per input batch."); +DEFINE_int32(hashaggr_jit_benchmark_groups, 10000, "Number of distinct groups."); + +namespace { + +struct HashAggrJitBenchmarkCase { + std::shared_ptr plan; +}; + +class HashAggrJitBenchmark : public VectorTestBase { + public: + void addBenchmark(const std::string& name, int32_t width) { + auto rows = makeRows(width); + std::vector sums; + std::vector avgs; + std::vector mins; + std::vector counts; + sums.reserve(width); + avgs.reserve(width); + mins.reserve(width); + counts.reserve(width); + for (auto i = 0; i < width; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + mins.push_back(fmt::format("min(c{})", i + 1)); + counts.push_back(fmt::format("count(c{})", i + 1)); + } + + addCase(name + "_sum", rows, sums); + addCase(name + "_avg", rows, avgs); + addCase(name + "_min", rows, mins); + addCase(name + "_count", rows, counts); + } + + 
private: + std::vector makeRows(int32_t width) { + std::vector names; + std::vector children; + names.reserve(width + 1); + children.reserve(width + 1); + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [](vector_size_t row) { return row % FLAGS_hashaggr_jit_benchmark_groups; })); + + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [column](vector_size_t row) { + return static_cast((row + 17 * column) & 0xffff); + })); + } + + auto batch = makeRowVector(names, children); + std::vector rows; + rows.reserve(FLAGS_hashaggr_jit_benchmark_batches); + for (auto i = 0; i < FLAGS_hashaggr_jit_benchmark_batches; ++i) { + rows.push_back(batch); + } + return rows; + } + + std::shared_ptr makePlan( + const std::vector& rows, + const std::vector& aggregates) { + return exec::test::PlanBuilder() + .values(rows) + .singleAggregation({"c0"}, aggregates) + .planNode(); + } + + void run(const std::shared_ptr& plan, bool enableJit) { + exec::test::AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, enableJit ? "true" : "false") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "4") + .config(core::QueryConfig::kHashAggrJitMaxFuseWidth, "16") + .config(core::QueryConfig::kHashAggrJitCompileMinCount, "3") + .copyResults(pool_.get()); + } + + void addCase( + const std::string& name, + const std::vector& rows, + const std::vector& aggregates) { + auto testCase = std::make_unique(); + testCase->plan = makePlan(rows, aggregates); + // Warm up both paths so the benchmark compares steady-state execution and + // doesn't charge one-time plan setup / JIT compilation to the first sample. 
+ run(testCase->plan, false); + run(testCase->plan, true); + auto* testCasePtr = testCase.get(); + folly::addBenchmark(__FILE__, name + "_nojit", [this, testCasePtr]() { + run(testCasePtr->plan, false); + return 1; + }); + folly::addBenchmark(__FILE__, name + "_jit", [this, testCasePtr]() { + run(testCasePtr->plan, true); + return 1; + }); + folly::addBenchmark(__FILE__, "-", []() { return 0; }); + cases_.push_back(std::move(testCase)); + } + + std::vector> cases_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::init(&argc, &argv); + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + aggregate::prestosql::registerAllAggregateFunctions(); + functions::aggregate::sparksql::registerAggregateFunctions("spark_", false); + parse::registerTypeResolver(); + + HashAggrJitBenchmark benchmark; + benchmark.addBenchmark("width4", 4); + benchmark.addBenchmark("width8", 8); + benchmark.addBenchmark("width16", 16); + benchmark.addBenchmark("width32", 32); + + folly::runBenchmarks(); + return 0; +} diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index c49b6e697..56c5e5130 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -156,6 +156,20 @@ void clearAccumulatorNull( builder.CreateAnd(byte, mask)); } +void setAccumulatorNull( + llvm::IRBuilder<>& builder, + llvm::Value* group, + const HashAggrJitSlot& slot) { + auto* byte = loadValue(builder, group, builder.getInt8Ty(), slot.nullByte); + auto* mask = llvm::ConstantInt::get(builder.getInt8Ty(), slot.nullMask); + storeValue( + builder, + group, + builder.getInt8Ty(), + slot.nullByte, + builder.CreateOr(byte, mask)); +} + llvm::Value* loadDecodedValue( llvm::IRBuilder<>& builder, llvm::Module& module, @@ -229,7 +243,89 @@ void genNonNullUpdate( } } -bool genAddDenseIR(llvm::Module& module, const std::string& fn, const std::vector& slots) { +bool genAddDenseIR( + llvm::Module& module, + const std::string& fn, + 
const std::vector& slots, + bool checkInputNulls); + +bool genInitIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots) { + auto& context = module.getContext(); + llvm::IRBuilder<> builder(context); + auto* voidTy = builder.getVoidTy(); + auto* i8PtrTy = llvm::PointerType::get(context, 0); + auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); + auto* i32Ty = builder.getInt32Ty(); + auto* funcTy = llvm::FunctionType::get(voidTy, {i8PtrPtrTy, i32Ty}, false); + auto* func = llvm::Function::Create( + funcTy, llvm::Function::ExternalLinkage, fn, module); + auto argIt = func->arg_begin(); + llvm::Value* newGroups = &*argIt++; + newGroups->setName("new_groups"); + llvm::Value* numNewGroups = &*argIt++; + numNewGroups->setName("num_new_groups"); + + auto* entry = llvm::BasicBlock::Create(context, "entry", func); + auto* loop = llvm::BasicBlock::Create(context, "loop", func); + auto* end = llvm::BasicBlock::Create(context, "end", func); + builder.SetInsertPoint(entry); + builder.CreateCondBr( + builder.CreateICmpSLE(numNewGroups, builder.getInt32(0)), end, loop); + + builder.SetInsertPoint(loop); + auto* index = builder.CreatePHI(i32Ty, 2, "idx"); + index->addIncoming(builder.getInt32(0), entry); + auto* groupAddr = builder.CreateInBoundsGEP(i8PtrTy, newGroups, index); + auto* group = builder.CreateLoad(i8PtrTy, groupAddr); + + for (const auto& slot : slots) { + if (slot.kind != HashAggrJitKind::Count) { + setAccumulatorNull(builder, group, slot); + } + auto* accType = llvmType(builder, slot.accumulatorKind); + if (isFloatKind(slot.accumulatorKind)) { + storeValue( + builder, + group, + accType, + slot.offset, + llvm::ConstantFP::get(accType, 0.0)); + } else { + storeValue( + builder, + group, + accType, + slot.offset, + llvm::ConstantInt::get(accType, 0)); + } + if (slot.kind == HashAggrJitKind::Avg) { + storeValue( + builder, + group, + builder.getInt64Ty(), + slot.offset + 8, + builder.getInt64(0)); + } + } + + auto* next = builder.CreateAdd(index, 
builder.getInt32(1)); + index->addIncoming(next, builder.GetInsertBlock()); + builder.CreateCondBr(builder.CreateICmpSLT(next, numNewGroups), loop, end); + + builder.SetInsertPoint(end); + builder.CreateRetVoid(); + + return llvm::verifyFunction(*func, &llvm::errs()); +} + +bool genAddDenseIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots, + bool checkInputNulls) { ensureBuiltinDeclarations(module); auto& context = module.getContext(); llvm::IRBuilder<> builder(context); @@ -266,14 +362,18 @@ bool genAddDenseIR(llvm::Module& module, const std::string& fn, const std::vecto continue; } - auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); - auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); - auto* isNull = builder.CreateICmpNE( - builder.CreateCall(module.getFunction("jit_GetDecodedIsNull"), {decoded, row}), - builder.getInt8(0)); auto* updateBlock = llvm::BasicBlock::Create(context, "slot_update", func, end); auto* nextBlock = llvm::BasicBlock::Create(context, "slot_next", func, end); - builder.CreateCondBr(isNull, nextBlock, updateBlock); + auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); + auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); + if (checkInputNulls) { + auto* isNull = builder.CreateICmpNE( + builder.CreateCall(module.getFunction("jit_GetDecodedIsNull"), {decoded, row}), + builder.getInt8(0)); + builder.CreateCondBr(isNull, nextBlock, updateBlock); + } else { + builder.CreateBr(updateBlock); + } builder.SetInsertPoint(updateBlock); if (slot.kind == HashAggrJitKind::Count) { @@ -335,7 +435,7 @@ bool isHashAggrJitSupportedType(TypeKind kind) { std::string HashAggrJitChunk::functionName() const { std::ostringstream out; - out << "jit_hashaggr_add_dense_v1_n" << slots_.size(); + out << "jit_hashaggr_v2_n" << slots_.size(); for (const auto& slot : slots_) { out << "_" << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) << 
hashAggrJitValueKindName(slot.accumulatorKind) << "o" << slot.offset @@ -345,6 +445,14 @@ std::string HashAggrJitChunk::functionName() const { return out.str(); } +std::string HashAggrJitChunk::initFunctionName() const { + return functionName() + "_init"; +} + +std::string HashAggrJitChunk::addDenseNoNullFunctionName() const { + return functionName() + "_add_dense_no_null"; +} + bool HashAggrJitChunk::codegen() { if (addDense_) { return true; @@ -353,15 +461,26 @@ bool HashAggrJitChunk::codegen() { if (jit == nullptr) { return false; } - const auto fn = functionName(); + const auto moduleKey = functionName(); + const auto initFn = initFunctionName(); + const auto addFn = moduleKey + "_add_dense"; + const auto addNoNullFn = addDenseNoNullFunctionName(); module_ = jit->CompileModule( - [&](llvm::Module& module) { return genAddDenseIR(module, fn, slots_); }, fn); + [&](llvm::Module& module) { + return genInitIR(module, initFn, slots_) || + genAddDenseIR(module, addFn, slots_, true) || + genAddDenseIR(module, addNoNullFn, slots_, false); + }, + moduleKey); if (!module_) { disabled_ = true; return false; } - addDense_ = reinterpret_cast(module_->getFuncPtr(fn)); - if (addDense_ == nullptr) { + init_ = reinterpret_cast(module_->getFuncPtr(initFn)); + addDense_ = reinterpret_cast(module_->getFuncPtr(addFn)); + addDenseNoNull_ = reinterpret_cast( + module_->getFuncPtr(addNoNullFn)); + if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr) { disabled_ = true; return false; } diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index aa9b52f7b..b1f9e39be 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -41,6 +41,7 @@ struct HashAggrJitSlot { }; using HashAggrJitAddDenseFunc = void (*)(char** groups, int32_t numRows, char** decodedInputs); +using HashAggrJitInitFunc = void (*)(char** newGroups, int32_t numNewGroups); class HashAggrJitChunk { public: @@ -56,7 +57,19 @@ class 
HashAggrJitChunk { disabled_ = true; } - void addDense(char** groups, int32_t numRows, char** decodedInputs) const { + void init(char** newGroups, int32_t numNewGroups) const { + init_(newGroups, numNewGroups); + } + + void addDense( + char** groups, + int32_t numRows, + char** decodedInputs, + bool inputsMayHaveNulls) const { + if (!inputsMayHaveNulls && addDenseNoNull_ != nullptr) { + addDenseNoNull_(groups, numRows, decodedInputs); + return; + } addDense_(groups, numRows, decodedInputs); } @@ -65,11 +78,15 @@ class HashAggrJitChunk { } std::string functionName() const; + std::string initFunctionName() const; + std::string addDenseNoNullFunctionName() const; private: std::vector slots_; CompiledModuleSP module_; + HashAggrJitInitFunc init_{nullptr}; HashAggrJitAddDenseFunc addDense_{nullptr}; + HashAggrJitAddDenseFunc addDenseNoNull_{nullptr}; bool disabled_{false}; }; From dccaad384fb2f9111400531e8288f7667e367b28 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 20 May 2026 15:01:04 +0800 Subject: [PATCH 03/98] wip --- bolt/exec/GroupingSet.cpp | 132 +++++-- bolt/exec/GroupingSet.h | 6 + bolt/exec/RowContainer.cpp | 85 +++++ bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 95 ++++- .../aggregates/tests/SumAggregationTest.cpp | 64 ++++ bolt/jit/aggregation/HashAggrJit.cpp | 343 +++++++++++++++++- bolt/jit/aggregation/HashAggrJit.h | 12 + 7 files changed, 693 insertions(+), 44 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index b5ded79f1..e59965832 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -102,6 +102,8 @@ std::optional hashAggrJitValueKind(TypeKind kind) { return jit::HashAggrJitValueKind::Int32; case TypeKind::BIGINT: return jit::HashAggrJitValueKind::Int64; + case TypeKind::HUGEINT: + return jit::HashAggrJitValueKind::Int128; case TypeKind::REAL: return jit::HashAggrJitValueKind::Float; case TypeKind::DOUBLE: @@ -111,16 +113,10 @@ std::optional hashAggrJitValueKind(TypeKind kind) { } } -bool 
isIntegralJitKind(jit::HashAggrJitValueKind kind) { - return kind == jit::HashAggrJitValueKind::Int8 || - kind == jit::HashAggrJitValueKind::Int16 || - kind == jit::HashAggrJitValueKind::Int32 || - kind == jit::HashAggrJitValueKind::Int64; -} - std::optional makeHashAggrJitSlot( int32_t aggregateIndex, - const AggregateInfo& aggregate) { + const AggregateInfo& aggregate, + bool isRawInput) { if (aggregate.distinct || aggregate.mask.has_value() || !aggregate.sortingKeys.empty()) { return std::nullopt; @@ -132,23 +128,47 @@ std::optional makeHashAggrJitSlot( } const bool countStar = *kind == jit::HashAggrJitKind::Count && - aggregate.inputs.empty(); - if (!countStar && aggregate.rawInputTypes.size() != 1) { + aggregate.inputs.empty() && isRawInput; + if (!countStar && aggregate.inputs.size() != 1) { return std::nullopt; } jit::HashAggrJitValueKind inputKind = jit::HashAggrJitValueKind::Int64; + bool decimal = false; if (!countStar) { - const auto& inputType = aggregate.rawInputTypes[0]; - if (inputType->isDecimal() || + const auto& inputType = isRawInput ? 
aggregate.rawInputTypes[0] + : aggregate.intermediateType; + decimal = isRawInput && inputType->isDecimal() && + (*kind == jit::HashAggrJitKind::Sum || *kind == jit::HashAggrJitKind::Avg); + if (decimal) { + auto maybeInputKind = hashAggrJitValueKind(inputType->kind()); + if (!maybeInputKind.has_value() || + (*maybeInputKind != jit::HashAggrJitValueKind::Int64 && + *maybeInputKind != jit::HashAggrJitValueKind::Int128)) { + return std::nullopt; + } + inputKind = *maybeInputKind; + } else if (!isRawInput && *kind == jit::HashAggrJitKind::Avg) { + if (!inputType->isRow() || inputType->size() != 2 || + inputType->childAt(1)->kind() != TypeKind::BIGINT) { + return std::nullopt; + } + auto maybeInputKind = hashAggrJitValueKind(inputType->childAt(0)->kind()); + if (!maybeInputKind.has_value() || + *maybeInputKind != jit::HashAggrJitValueKind::Double) { + return std::nullopt; + } + inputKind = *maybeInputKind; + } else if (inputType->isDecimal() || inputType->isRow() || !jit::isHashAggrJitSupportedType(inputType->kind())) { return std::nullopt; + } else { + auto maybeInputKind = hashAggrJitValueKind(inputType->kind()); + if (!maybeInputKind.has_value()) { + return std::nullopt; + } + inputKind = *maybeInputKind; } - auto maybeInputKind = hashAggrJitValueKind(inputType->kind()); - if (!maybeInputKind.has_value()) { - return std::nullopt; - } - inputKind = *maybeInputKind; } jit::HashAggrJitValueKind accumulatorKind = inputKind; @@ -157,21 +177,24 @@ std::optional makeHashAggrJitSlot( accumulatorKind = jit::HashAggrJitValueKind::Int64; break; case jit::HashAggrJitKind::Sum: + if (decimal) { + accumulatorKind = jit::HashAggrJitValueKind::Int128; + break; + } accumulatorKind = (inputKind == jit::HashAggrJitValueKind::Float || inputKind == jit::HashAggrJitValueKind::Double) ? 
jit::HashAggrJitValueKind::Double : jit::HashAggrJitValueKind::Int64; break; case jit::HashAggrJitKind::Avg: + if (decimal) { + accumulatorKind = jit::HashAggrJitValueKind::Int128; + break; + } accumulatorKind = jit::HashAggrJitValueKind::Double; break; case jit::HashAggrJitKind::Min: case jit::HashAggrJitKind::Max: - // Keep floating-point min/max on the existing implementation to preserve - // Velox/Spark NaN ordering edge cases exactly. - if (!isIntegralJitKind(inputKind)) { - return std::nullopt; - } accumulatorKind = inputKind; break; } @@ -184,15 +207,19 @@ std::optional makeHashAggrJitSlot( aggregate.function->accumulatorOffset(), aggregate.function->accumulatorNullByte(), aggregate.function->accumulatorNullMask(), - countStar}; + countStar, + !isRawInput, + decimal}; } std::string hashAggrJitSignature(const jit::HashAggrJitSlot& slot) { return fmt::format( - "{}_{}_{}", + "{}_{}_{}_{}_{}", static_cast(slot.kind), jit::hashAggrJitValueKindName(slot.inputKind), - jit::hashAggrJitValueKindName(slot.accumulatorKind)); + jit::hashAggrJitValueKindName(slot.accumulatorKind), + slot.mergeInput, + slot.decimal); } #endif @@ -921,8 +948,7 @@ const SelectivityVector& GroupingSet::getSelectivityVector( #ifdef ENABLE_BOLT_JIT void GroupingSet::maybeCreateHashAggrJitPlan() { hashAggrJitChunks_.clear(); - if (!queryConfig_.enableHashAggrJit() || !isRawInput_ || isGlobal_ || - ignoreNullKeys_) { + if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { return; } @@ -933,7 +959,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto minChunkWidth = std::max(minFuseWidth, compileMinCount); std::unordered_map> groups; for (auto i = 0; i < aggregates_.size(); ++i) { - auto slot = makeHashAggrJitSlot(i, aggregates_[i]); + auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_); if (!slot.has_value()) { continue; } @@ -1043,6 +1069,49 @@ void GroupingSet::runHashAggrJitChunks( } } } + +void GroupingSet::runHashAggrJitExtractChunks( + 
folly::Range groups, + const RowVectorPtr& result, + int32_t aggregateOutputOffset, + std::vector& jitExtracted) { + if (hashAggrJitChunks_.empty() || groups.empty() || hasSpilled() || + supportRowBasedOutput_) { + return; + } + + jitExtracted.assign(aggregates_.size(), 0); + for (auto& chunk : hashAggrJitChunks_) { + if (!chunk.canExtract(isPartial_)) { + continue; + } + const auto numSlots = chunk.slots().size(); + hashAggrJitResultPtrs_.assign(numSlots, nullptr); + bool canRunChunk = true; + for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { + const auto& slot = chunk.slots()[slotIndex]; + const auto& aggregate = aggregates_[slot.aggregateIndex]; + if (aggregate.distinct || aggregate.mask.has_value() || + !aggregate.sortingKeys.empty()) { + canRunChunk = false; + break; + } + auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); + if (aggregateVector->encoding() != VectorEncoding::Simple::FLAT) { + canRunChunk = false; + break; + } + hashAggrJitResultPtrs_[slotIndex] = reinterpret_cast(aggregateVector.get()); + } + if (!canRunChunk) { + continue; + } + chunk.extract(groups.data(), groups.size(), hashAggrJitResultPtrs_.data()); + for (const auto& slot : chunk.slots()) { + jitExtracted[slot.aggregateIndex] = 1; + } + } +} #endif bool GroupingSet::getOutput( @@ -1155,7 +1224,14 @@ void GroupingSet::extractGroups( rows.extractColumn(groups.data(), groups.size(), i, keyVector); } } + std::vector jitExtracted; +#ifdef ENABLE_BOLT_JIT + runHashAggrJitExtractChunks(groups, result, totalKeys, jitExtracted); +#endif for (int32_t i = 0; i < aggregates_.size(); ++i) { + if (!jitExtracted.empty() && jitExtracted[i]) { + continue; + } if (!aggregates_[i].sortingKeys.empty()) { continue; } diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 9689c63de..0781910b5 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -294,6 +294,11 @@ class GroupingSet { const RowVectorPtr& input, bool mayPushdown, 
std::vector& jitExecuted); + void runHashAggrJitExtractChunks( + folly::Range groups, + const RowVectorPtr& result, + int32_t aggregateOutputOffset, + std::vector& jitExtracted); #endif // Checks if input will fit in the existing memory and increases reservation @@ -462,6 +467,7 @@ class GroupingSet { std::vector hashAggrJitInputVectors_; std::vector hashAggrJitDecodedPtrs_; std::vector hashAggrJitNewGroups_; + std::vector hashAggrJitResultPtrs_; #endif // True if any aggregate accumulator allocates memory outside RowContainer's diff --git a/bolt/exec/RowContainer.cpp b/bolt/exec/RowContainer.cpp index 9412a2dbd..685a3872d 100644 --- a/bolt/exec/RowContainer.cpp +++ b/bolt/exec/RowContainer.cpp @@ -37,6 +37,7 @@ #include "bolt/type/StringView.h" #include "bolt/type/Timestamp.h" #include "bolt/vector/DecodedVector.h" +#include "bolt/vector/FlatVector.h" #include "bolt/common/memory/ByteStream.h" #include "bolt/common/memory/RawVector.h" @@ -1746,6 +1747,24 @@ __attribute__((__visibility__("default"))) double jit_GetDecodedValueDouble( return reinterpret_cast(vec) ->valueAt(index); } + +__attribute__((__visibility__("default"))) double +jit_GetDecodedRowFieldDouble(char* vec, int32_t index, int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + +__attribute__((__visibility__("default"))) int64_t jit_GetDecodedRowFieldI64( + char* vec, + int32_t index, + int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} // get decoded value string __attribute__((__visibility__("default"))) const char* jit_GetDecodedValueStringView(char* vec, int32_t index) { @@ -1774,6 +1793,72 @@ __attribute__((__visibility__("default"))) 
int8_t jit_GetDecodedIsNull( index); } +__attribute__((__visibility__("default"))) void jit_HashAggrResizeVector( + char* vector, + int32_t size) { + reinterpret_cast(vector)->resize(size); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI8( + char* vector, + int32_t row, + int8_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI16( + char* vector, + int32_t row, + int16_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI32( + char* vector, + int32_t row, + int32_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI64( + char* vector, + int32_t row, + int64_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatFloat( + char* vector, + int32_t row, + float value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatDouble( + char* vector, + int32_t row, + double value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? 
flat->setNull(row, true) : flat->set(row, value); +} + __attribute__((__visibility__("default"))) int8_t jit_ComplexTypeRowEqVectors( const char* row, int32_t offset, diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 598f74652..0b91ccd02 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -6,6 +6,8 @@ #include #include +#include + #include "bolt/core/QueryConfig.h" #include "bolt/exec/tests/utils/AssertQueryBuilder.h" #include "bolt/exec/tests/utils/PlanBuilder.h" @@ -51,6 +53,34 @@ class HashAggrJitBenchmark : public VectorTestBase { addCase(name + "_avg", rows, avgs); addCase(name + "_min", rows, mins); addCase(name + "_count", rows, counts); + addCase(name + "_merge_sum", rows, sums, true); + addCase(name + "_merge_avg", rows, avgs, true); + addCase(name + "_merge_min", rows, mins, true); + addCase(name + "_merge_count", rows, counts, true); + } + + void addDecimalBenchmark(const std::string& name, int32_t width) { + auto rows = makeDecimalRows(width); + std::vector sums; + std::vector avgs; + for (auto i = 0; i < width; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + } + addCase(name + "_decimal_sum", rows, sums); + addCase(name + "_decimal_avg", rows, avgs); + } + + void addFloatingPointMinMaxBenchmark(const std::string& name, int32_t width) { + auto rows = makeDoubleRows(width); + std::vector mins; + std::vector maxs; + for (auto i = 0; i < width; ++i) { + mins.push_back(fmt::format("min(c{})", i + 1)); + maxs.push_back(fmt::format("max(c{})", i + 1)); + } + addCase(name + "_double_min", rows, mins); + addCase(name + "_double_max", rows, maxs); } private: @@ -82,13 +112,61 @@ class HashAggrJitBenchmark : public VectorTestBase { return rows; } + std::vector makeDecimalRows(int32_t width) { + std::vector names; + std::vector children; + names.push_back("c0"); + 
children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [](vector_size_t row) { return row % FLAGS_hashaggr_jit_benchmark_groups; })); + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [column](vector_size_t row) { + return static_cast((row + 13 * column) % 100000); + }, + nullptr, + DECIMAL(12, 2))); + } + auto batch = makeRowVector(names, children); + return std::vector(FLAGS_hashaggr_jit_benchmark_batches, batch); + } + + std::vector makeDoubleRows(int32_t width) { + std::vector names; + std::vector children; + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [](vector_size_t row) { return row % FLAGS_hashaggr_jit_benchmark_groups; })); + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [column](vector_size_t row) { + if ((row + column) % 997 == 0) { + return std::numeric_limits::quiet_NaN(); + } + return static_cast((row + 17 * column) & 0xffff); + })); + } + auto batch = makeRowVector(names, children); + return std::vector(FLAGS_hashaggr_jit_benchmark_batches, batch); + } + std::shared_ptr makePlan( const std::vector& rows, - const std::vector& aggregates) { - return exec::test::PlanBuilder() - .values(rows) - .singleAggregation({"c0"}, aggregates) - .planNode(); + const std::vector& aggregates, + bool partialFinal = false) { + exec::test::PlanBuilder builder; + builder.values(rows); + if (partialFinal) { + builder.partialAggregation({"c0"}, aggregates).finalAggregation(); + } else { + builder.singleAggregation({"c0"}, aggregates); + } + return builder.planNode(); } void run(const std::shared_ptr& plan, bool enableJit) { @@ -103,9 +181,10 @@ class HashAggrJitBenchmark : public VectorTestBase { void addCase( const 
std::string& name, const std::vector& rows, - const std::vector& aggregates) { + const std::vector& aggregates, + bool partialFinal = false) { auto testCase = std::make_unique(); - testCase->plan = makePlan(rows, aggregates); + testCase->plan = makePlan(rows, aggregates, partialFinal); // Warm up both paths so the benchmark compares steady-state execution and // doesn't charge one-time plan setup / JIT compilation to the first sample. run(testCase->plan, false); @@ -140,6 +219,8 @@ int main(int argc, char** argv) { benchmark.addBenchmark("width8", 8); benchmark.addBenchmark("width16", 16); benchmark.addBenchmark("width32", 32); + benchmark.addDecimalBenchmark("width8", 8); + benchmark.addFloatingPointMinMaxBenchmark("width8", 8); folly::runBenchmarks(); return 0; diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 488f06ae8..d472bd696 100644 --- a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -30,9 +30,12 @@ #include "bolt/exec/tests/utils/AssertQueryBuilder.h" #include "bolt/exec/tests/utils/PlanBuilder.h" +#include "bolt/exec/tests/utils/QueryAssertions.h" #include "bolt/functions/lib/aggregates/tests/SumTestBase.h" #include "bolt/functions/sparksql/aggregates/Register.h" +#include + using bytedance::bolt::exec::test::PlanBuilder; using namespace bytedance::bolt::exec::test; using namespace bytedance::bolt::functions::aggregate::test; @@ -131,6 +134,67 @@ TEST_F(SumAggregationTest, hookLimits) { testHookLimits(); } +TEST_F(SumAggregationTest, hashAggrJitDecimalSumAndFloatingMinMax) { + auto input = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 8; }), + makeFlatVector( + 256, [](auto row) { return row * 100; }, nullptr, DECIMAL(12, 2)), + makeFlatVector(256, [](auto row) { + return row % 31 == 0 ? 
std::numeric_limits::quiet_NaN() + : static_cast(row); + }), + makeFlatVector(256, [](auto row) { + return row % 37 == 0 ? std::numeric_limits::quiet_NaN() + : static_cast(1000 - row); + })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .singleAggregation( + {"c0"}, + {"spark_sum(c1)", "min(c2)", "max(c3)"}) + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + +TEST_F(SumAggregationTest, hashAggrJitMergeAndExtract) { + auto input = makeRowVector( + {makeFlatVector(512, [](auto row) { return row % 16; }), + makeFlatVector(512, [](auto row) { return row; }), + makeFlatVector(512, [](auto row) { return 1000 - row; }), + makeFlatVector( + 512, + [](auto row) { return row; }, + [](auto row) { return row % 7 == 0; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .partialAggregation( + {"c0"}, + {"spark_sum(c1)", "spark_avg(c1)", "min(c2)", "count(c3)"}) + .finalAggregation() + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + TEST_F(SumAggregationTest, decimalSum) { std::vector> shortDecimalRawVector; std::vector> longDecimalRawVector; diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 56c5e5130..91a1c1720 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ 
b/bolt/jit/aggregation/HashAggrJit.cpp @@ -6,10 +6,95 @@ #include #include +#include #include #include "bolt/jit/ThrustJITv2.h" +extern "C" { + +namespace { + +struct JitDecimalSumState { + bytedance::bolt::int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; +}; + +struct JitDecimalAvgState { + bytedance::bolt::int128_t sum{0}; + int64_t count{0}; + int64_t overflow{0}; +}; + +int64_t jitHashAggrAddWithOverflow( + bytedance::bolt::int128_t left, + bytedance::bolt::int128_t right, + bytedance::bolt::int128_t& result) { + result = left + right; + if (left > 0 && right > 0 && result < 0) { + return 1; + } + if (left < 0 && right < 0 && result >= 0) { + return -1; + } + return 0; +} + +} // namespace + +__attribute__((__visibility__("default"))) void jit_HashAggrInitDecimalSum( + char* group, + int32_t offset) { + new (group + offset) JitDecimalSumState(); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrInitDecimalAvg( + char* group, + int32_t offset) { + new (group + offset) JitDecimalAvgState(); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalSumI64( + char* group, + int32_t offset, + int64_t value) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow( + state->sum, static_cast(value), state->sum); + state->isEmpty = false; +} + +__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalSumI128( + char* group, + int32_t offset, + bytedance::bolt::int128_t value) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); + state->isEmpty = false; +} + +__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalAvgI64( + char* group, + int32_t offset, + int64_t value) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow( + state->sum, static_cast(value), state->sum); + ++state->count; +} + 
+__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalAvgI128( + char* group, + int32_t offset, + bytedance::bolt::int128_t value) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); + ++state->count; +} + +} // extern "C" + namespace bytedance::bolt::jit { namespace { @@ -28,19 +113,43 @@ void ensureBuiltinDeclarations(llvm::Module& module) { auto* i16Ty = llvm::Type::getInt16Ty(context); auto* i32Ty = llvm::Type::getInt32Ty(context); auto* i64Ty = llvm::Type::getInt64Ty(context); + auto* i128Ty = llvm::Type::getInt128Ty(context); auto* floatTy = llvm::Type::getFloatTy(context); auto* doubleTy = llvm::Type::getDoubleTy(context); + auto* voidTy = llvm::Type::getVoidTy(context); auto* i8PtrTy = llvm::PointerType::get(context, 0); declareFunction(module, "jit_GetDecodedValueI8", i8Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_GetDecodedValueI16", i16Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_GetDecodedValueI32", i32Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_GetDecodedValueI64", i64Ty, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_GetDecodedValueI128", i128Ty, {i8PtrTy, i32Ty}); declareFunction( module, "jit_GetDecodedValueFloat", floatTy, {i8PtrTy, i32Ty}); declareFunction( module, "jit_GetDecodedValueDouble", doubleTy, {i8PtrTy, i32Ty}); + declareFunction( + module, "jit_GetDecodedRowFieldDouble", doubleTy, {i8PtrTy, i32Ty, i32Ty}); + declareFunction( + module, "jit_GetDecodedRowFieldI64", i64Ty, {i8PtrTy, i32Ty, i32Ty}); declareFunction(module, "jit_GetDecodedIsNull", i8Ty, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_HashAggrInitDecimalSum", voidTy, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_HashAggrInitDecimalAvg", voidTy, {i8PtrTy, i32Ty}); + declareFunction( + module, "jit_HashAggrUpdateDecimalSumI64", voidTy, {i8PtrTy, i32Ty, i64Ty}); + declareFunction( + module, "jit_HashAggrUpdateDecimalSumI128", voidTy, {i8PtrTy, i32Ty, 
i128Ty}); + declareFunction( + module, "jit_HashAggrUpdateDecimalAvgI64", voidTy, {i8PtrTy, i32Ty, i64Ty}); + declareFunction( + module, "jit_HashAggrUpdateDecimalAvgI128", voidTy, {i8PtrTy, i32Ty, i128Ty}); + declareFunction(module, "jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); + declareFunction(module, "jit_HashAggrSetFlatI8", voidTy, {i8PtrTy, i32Ty, i8Ty, i8Ty}); + declareFunction(module, "jit_HashAggrSetFlatI16", voidTy, {i8PtrTy, i32Ty, i16Ty, i8Ty}); + declareFunction(module, "jit_HashAggrSetFlatI32", voidTy, {i8PtrTy, i32Ty, i32Ty, i8Ty}); + declareFunction(module, "jit_HashAggrSetFlatI64", voidTy, {i8PtrTy, i32Ty, i64Ty, i8Ty}); + declareFunction(module, "jit_HashAggrSetFlatFloat", voidTy, {i8PtrTy, i32Ty, floatTy, i8Ty}); + declareFunction(module, "jit_HashAggrSetFlatDouble", voidTy, {i8PtrTy, i32Ty, doubleTy, i8Ty}); } llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { @@ -53,6 +162,8 @@ llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { return builder.getInt32Ty(); case HashAggrJitValueKind::Int64: return builder.getInt64Ty(); + case HashAggrJitValueKind::Int128: + return builder.getInt128Ty(); case HashAggrJitValueKind::Float: return builder.getFloatTy(); case HashAggrJitValueKind::Double: @@ -71,6 +182,8 @@ std::string decodedValueFunction(HashAggrJitValueKind kind) { return "jit_GetDecodedValueI32"; case HashAggrJitValueKind::Int64: return "jit_GetDecodedValueI64"; + case HashAggrJitValueKind::Int128: + return "jit_GetDecodedValueI128"; case HashAggrJitValueKind::Float: return "jit_GetDecodedValueFloat"; case HashAggrJitValueKind::Double: @@ -195,11 +308,28 @@ void genCountUpdate( void genNonNullUpdate( llvm::IRBuilder<>& builder, + llvm::Module& module, llvm::Value* group, llvm::Value* rawValue, const HashAggrJitSlot& slot) { auto* accType = llvmType(builder, slot.accumulatorKind); auto* value = castValue(builder, rawValue, slot.inputKind, slot.accumulatorKind); + if (slot.decimal) { + 
clearAccumulatorNull(builder, group, slot); + const auto helper = slot.kind == HashAggrJitKind::Sum + ? (slot.inputKind == HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalSumI128" + : "jit_HashAggrUpdateDecimalSumI64") + : (slot.inputKind == HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalAvgI128" + : "jit_HashAggrUpdateDecimalAvgI64"); + builder.CreateCall( + module.getFunction(helper), + {group, + builder.getInt32(slot.offset), + slot.inputKind == HashAggrJitValueKind::Int128 ? value : rawValue}); + return; + } switch (slot.kind) { case HashAggrJitKind::Sum: { clearAccumulatorNull(builder, group, slot); @@ -228,9 +358,26 @@ void genNonNullUpdate( case HashAggrJitKind::Max: { auto* oldValue = loadValue(builder, group, accType, slot.offset); auto* nullState = isAccumulatorNull(builder, group, slot); - auto* better = slot.kind == HashAggrJitKind::Min - ? builder.CreateICmpSLT(value, oldValue) - : builder.CreateICmpSGT(value, oldValue); + llvm::Value* better; + if (isFloatKind(slot.accumulatorKind)) { + auto* oldIsNan = builder.CreateFCmpUNO(oldValue, oldValue); + auto* valueIsNan = builder.CreateFCmpUNO(value, value); + if (slot.kind == HashAggrJitKind::Min) { + better = builder.CreateOr( + builder.CreateAnd(oldIsNan, builder.CreateNot(valueIsNan)), + builder.CreateAnd( + builder.CreateNot(valueIsNan), + builder.CreateFCmpOGT(oldValue, value))); + } else { + better = builder.CreateAnd( + builder.CreateNot(oldIsNan), + builder.CreateOr(valueIsNan, builder.CreateFCmpOLT(oldValue, value))); + } + } else { + better = slot.kind == HashAggrJitKind::Min + ? 
builder.CreateICmpSLT(value, oldValue) + : builder.CreateICmpSGT(value, oldValue); + } auto* shouldStore = builder.CreateOr(nullState, better); auto* selected = builder.CreateSelect(shouldStore, value, oldValue); storeValue(builder, group, accType, slot.offset, selected); @@ -238,11 +385,53 @@ void genNonNullUpdate( break; } case HashAggrJitKind::Count: - genCountUpdate(builder, group, slot); + if (slot.mergeInput) { + auto* state = loadValue(builder, group, builder.getInt64Ty(), slot.offset); + storeValue( + builder, + group, + builder.getInt64Ty(), + slot.offset, + builder.CreateAdd(state, castValue(builder, rawValue, slot.inputKind, HashAggrJitValueKind::Int64))); + } else { + genCountUpdate(builder, group, slot); + } break; } } +void genAvgMergeUpdate( + llvm::IRBuilder<>& builder, + llvm::Module& module, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot) { + clearAccumulatorNull(builder, group, slot); + auto* sum = builder.CreateCall( + module.getFunction("jit_GetDecodedRowFieldDouble"), + {decoded, row, builder.getInt32(0)}); + auto* count = builder.CreateCall( + module.getFunction("jit_GetDecodedRowFieldI64"), + {decoded, row, builder.getInt32(1)}); + + auto* oldSum = loadValue(builder, group, builder.getDoubleTy(), slot.offset); + storeValue( + builder, + group, + builder.getDoubleTy(), + slot.offset, + builder.CreateFAdd(oldSum, sum)); + + auto* oldCount = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); + storeValue( + builder, + group, + builder.getInt64Ty(), + slot.offset + 8, + builder.CreateAdd(oldCount, count)); +} + bool genAddDenseIR( llvm::Module& module, const std::string& fn, @@ -253,6 +442,7 @@ bool genInitIR( llvm::Module& module, const std::string& fn, const std::vector& slots) { + ensureBuiltinDeclarations(module); auto& context = module.getContext(); llvm::IRBuilder<> builder(context); auto* voidTy = builder.getVoidTy(); @@ -285,6 +475,14 @@ bool genInitIR( if (slot.kind != 
HashAggrJitKind::Count) { setAccumulatorNull(builder, group, slot); } + if (slot.decimal) { + builder.CreateCall( + module.getFunction( + slot.kind == HashAggrJitKind::Sum ? "jit_HashAggrInitDecimalSum" + : "jit_HashAggrInitDecimalAvg"), + {group, builder.getInt32(slot.offset)}); + continue; + } auto* accType = llvmType(builder, slot.accumulatorKind); if (isFloatKind(slot.accumulatorKind)) { storeValue( @@ -376,11 +574,13 @@ bool genAddDenseIR( } builder.SetInsertPoint(updateBlock); - if (slot.kind == HashAggrJitKind::Count) { + if (slot.kind == HashAggrJitKind::Count && !slot.mergeInput) { genCountUpdate(builder, group, slot); + } else if (slot.kind == HashAggrJitKind::Avg && slot.mergeInput) { + genAvgMergeUpdate(builder, module, group, decoded, row, slot); } else { auto* value = loadDecodedValue(builder, module, decoded, row, slot); - genNonNullUpdate(builder, group, value, slot); + genNonNullUpdate(builder, module, group, value, slot); } builder.CreateBr(nextBlock); builder.SetInsertPoint(nextBlock); @@ -396,6 +596,107 @@ bool genAddDenseIR( return llvm::verifyFunction(*func, &llvm::errs()); } +std::string setFlatValueFunction(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Int8: + return "jit_HashAggrSetFlatI8"; + case HashAggrJitValueKind::Int16: + return "jit_HashAggrSetFlatI16"; + case HashAggrJitValueKind::Int32: + return "jit_HashAggrSetFlatI32"; + case HashAggrJitValueKind::Int64: + return "jit_HashAggrSetFlatI64"; + case HashAggrJitValueKind::Float: + return "jit_HashAggrSetFlatFloat"; + case HashAggrJitValueKind::Double: + return "jit_HashAggrSetFlatDouble"; + case HashAggrJitValueKind::Int128: + return ""; + } + return ""; +} + +bool genExtractIR( + llvm::Module& module, + const std::string& fn, + const std::vector& slots) { + ensureBuiltinDeclarations(module); + auto& context = module.getContext(); + llvm::IRBuilder<> builder(context); + auto* voidTy = builder.getVoidTy(); + auto* i8PtrTy = llvm::PointerType::get(context, 
0); + auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); + auto* i32Ty = builder.getInt32Ty(); + auto* funcTy = llvm::FunctionType::get(voidTy, {i8PtrPtrTy, i32Ty, i8PtrPtrTy}, false); + auto* func = llvm::Function::Create(funcTy, llvm::Function::ExternalLinkage, fn, module); + auto argIt = func->arg_begin(); + llvm::Value* groups = &*argIt++; + groups->setName("groups"); + llvm::Value* numGroups = &*argIt++; + numGroups->setName("num_groups"); + llvm::Value* resultVectors = &*argIt++; + resultVectors->setName("result_vectors"); + + auto* entry = llvm::BasicBlock::Create(context, "entry", func); + auto* loop = llvm::BasicBlock::Create(context, "loop", func); + auto* end = llvm::BasicBlock::Create(context, "end", func); + builder.SetInsertPoint(entry); + for (auto i = 0; i < slots.size(); ++i) { + if (slots[i].decimal || slots[i].accumulatorKind == HashAggrJitValueKind::Int128) { + continue; + } + auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); + auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); + builder.CreateCall(module.getFunction("jit_HashAggrResizeVector"), {vector, numGroups}); + } + builder.CreateCondBr(builder.CreateICmpSLE(numGroups, builder.getInt32(0)), end, loop); + + builder.SetInsertPoint(loop); + auto* row = builder.CreatePHI(i32Ty, 2, "row"); + row->addIncoming(builder.getInt32(0), entry); + auto* groupAddr = builder.CreateInBoundsGEP(i8PtrTy, groups, row); + auto* group = builder.CreateLoad(i8PtrTy, groupAddr); + + for (auto i = 0; i < slots.size(); ++i) { + const auto& slot = slots[i]; + if (slot.decimal || slot.accumulatorKind == HashAggrJitValueKind::Int128) { + continue; + } + auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); + auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); + HashAggrJitValueKind resultKind = slot.accumulatorKind; + llvm::Value* value = nullptr; + llvm::Value* isNull = nullptr; + if (slot.kind == HashAggrJitKind::Avg) { + auto* sum = loadValue(builder, 
group, llvmType(builder, slot.accumulatorKind), slot.offset); + auto* count = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); + auto* countIsZero = builder.CreateICmpEQ(count, builder.getInt64(0)); + auto* divisor = builder.CreateSIToFP(count, llvmType(builder, slot.accumulatorKind)); + value = builder.CreateFDiv(sum, divisor); + isNull = builder.CreateZExt(countIsZero, builder.getInt8Ty()); + } else { + value = loadValue(builder, group, llvmType(builder, resultKind), slot.offset); + isNull = slot.kind == HashAggrJitKind::Count + ? builder.getInt8(0) + : builder.CreateZExt(isAccumulatorNull(builder, group, slot), builder.getInt8Ty()); + } + const auto setter = setFlatValueFunction(resultKind); + if (setter.empty()) { + continue; + } + builder.CreateCall(module.getFunction(setter), {vector, row, value, isNull}); + } + + auto* next = builder.CreateAdd(row, builder.getInt32(1)); + row->addIncoming(next, builder.GetInsertBlock()); + builder.CreateCondBr(builder.CreateICmpSLT(next, numGroups), loop, end); + + builder.SetInsertPoint(end); + builder.CreateRetVoid(); + + return llvm::verifyFunction(*func, &llvm::errs()); +} + } // namespace HashAggrJitChunk::HashAggrJitChunk(std::vector slots) @@ -411,6 +712,8 @@ std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { return "i32"; case HashAggrJitValueKind::Int64: return "i64"; + case HashAggrJitValueKind::Int128: + return "i128"; case HashAggrJitValueKind::Float: return "f32"; case HashAggrJitValueKind::Double: @@ -440,11 +743,25 @@ std::string HashAggrJitChunk::functionName() const { out << "_" << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) << hashAggrJitValueKindName(slot.accumulatorKind) << "o" << slot.offset << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) - << (slot.countStar ? "s" : "x"); + << (slot.countStar ? "s" : "x") << (slot.mergeInput ? "g" : "r") + << (slot.decimal ? 
"d" : "n"); } return out.str(); } +bool HashAggrJitChunk::canExtract(bool partialOutput) const { + if (extract_ == nullptr || disabled_) { + return false; + } + for (const auto& slot : slots_) { + if (slot.decimal || slot.accumulatorKind == HashAggrJitValueKind::Int128 || + (partialOutput && slot.kind == HashAggrJitKind::Avg)) { + return false; + } + } + return true; +} + std::string HashAggrJitChunk::initFunctionName() const { return functionName() + "_init"; } @@ -453,6 +770,10 @@ std::string HashAggrJitChunk::addDenseNoNullFunctionName() const { return functionName() + "_add_dense_no_null"; } +std::string HashAggrJitChunk::extractFunctionName() const { + return functionName() + "_extract"; +} + bool HashAggrJitChunk::codegen() { if (addDense_) { return true; @@ -465,11 +786,13 @@ bool HashAggrJitChunk::codegen() { const auto initFn = initFunctionName(); const auto addFn = moduleKey + "_add_dense"; const auto addNoNullFn = addDenseNoNullFunctionName(); + const auto extractFn = extractFunctionName(); module_ = jit->CompileModule( [&](llvm::Module& module) { return genInitIR(module, initFn, slots_) || genAddDenseIR(module, addFn, slots_, true) || - genAddDenseIR(module, addNoNullFn, slots_, false); + genAddDenseIR(module, addNoNullFn, slots_, false) || + genExtractIR(module, extractFn, slots_); }, moduleKey); if (!module_) { @@ -480,7 +803,9 @@ bool HashAggrJitChunk::codegen() { addDense_ = reinterpret_cast(module_->getFuncPtr(addFn)); addDenseNoNull_ = reinterpret_cast( module_->getFuncPtr(addNoNullFn)); - if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr) { + extract_ = reinterpret_cast(module_->getFuncPtr(extractFn)); + if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr || + extract_ == nullptr) { disabled_ = true; return false; } diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index b1f9e39be..5af411a8a 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ 
b/bolt/jit/aggregation/HashAggrJit.h @@ -25,6 +25,7 @@ enum class HashAggrJitValueKind : uint8_t { Int16, Int32, Int64, + Int128, Float, Double, }; @@ -38,10 +39,13 @@ struct HashAggrJitSlot { int32_t nullByte; uint8_t nullMask; bool countStar{false}; + bool mergeInput{false}; + bool decimal{false}; }; using HashAggrJitAddDenseFunc = void (*)(char** groups, int32_t numRows, char** decodedInputs); using HashAggrJitInitFunc = void (*)(char** newGroups, int32_t numNewGroups); +using HashAggrJitExtractFunc = void (*)(char** groups, int32_t numGroups, char** resultVectors); class HashAggrJitChunk { public: @@ -53,6 +57,8 @@ class HashAggrJitChunk { return addDense_ != nullptr && !disabled_; } + bool canExtract(bool partialOutput) const; + void disable() { disabled_ = true; } @@ -73,6 +79,10 @@ class HashAggrJitChunk { addDense_(groups, numRows, decodedInputs); } + void extract(char** groups, int32_t numGroups, char** resultVectors) const { + extract_(groups, numGroups, resultVectors); + } + const std::vector& slots() const { return slots_; } @@ -80,6 +90,7 @@ class HashAggrJitChunk { std::string functionName() const; std::string initFunctionName() const; std::string addDenseNoNullFunctionName() const; + std::string extractFunctionName() const; private: std::vector slots_; @@ -87,6 +98,7 @@ class HashAggrJitChunk { HashAggrJitInitFunc init_{nullptr}; HashAggrJitAddDenseFunc addDense_{nullptr}; HashAggrJitAddDenseFunc addDenseNoNull_{nullptr}; + HashAggrJitExtractFunc extract_{nullptr}; bool disabled_{false}; }; From ac4b58223bd893424ccda2b5446f71427041f2b6 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 1 Jun 2026 10:36:10 +0800 Subject: [PATCH 04/98] commit again --- bolt/exec/GroupingSet.cpp | 10 ++- bolt/exec/RowContainer.cpp | 20 +++++ bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 78 ++++++++++++++++--- .../aggregates/tests/SumAggregationTest.cpp | 23 ++++++ bolt/jit/aggregation/HashAggrJit.cpp | 36 +++++++-- bolt/jit/aggregation/HashAggrJit.h | 7 +- 6 files 
changed, 150 insertions(+), 24 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index e59965832..7721ed171 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -977,7 +977,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { if (chunkSlots.size() < minChunkWidth) { continue; } - jit::HashAggrJitChunk chunk(std::move(chunkSlots)); + jit::HashAggrJitChunk chunk(std::move(chunkSlots), isPartial_); if (chunk.codegen()) { hashAggrJitChunks_.push_back(std::move(chunk)); } @@ -1082,7 +1082,7 @@ void GroupingSet::runHashAggrJitExtractChunks( jitExtracted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { - if (!chunk.canExtract(isPartial_)) { + if (!chunk.canExtract()) { continue; } const auto numSlots = chunk.slots().size(); @@ -1097,7 +1097,11 @@ void GroupingSet::runHashAggrJitExtractChunks( break; } auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); - if (aggregateVector->encoding() != VectorEncoding::Simple::FLAT) { + const auto expectedEncoding = + (isPartial_ && slot.kind == jit::HashAggrJitKind::Avg) + ? VectorEncoding::Simple::ROW + : VectorEncoding::Simple::FLAT; + if (aggregateVector->encoding() != expectedEncoding) { canRunChunk = false; break; } diff --git a/bolt/exec/RowContainer.cpp b/bolt/exec/RowContainer.cpp index 685a3872d..f887a4479 100644 --- a/bolt/exec/RowContainer.cpp +++ b/bolt/exec/RowContainer.cpp @@ -1859,6 +1859,26 @@ __attribute__((__visibility__("default"))) void jit_HashAggrSetFlatDouble( isNull ? 
flat->setNull(row, true) : flat->set(row, value); } +__attribute__((__visibility__("default"))) void jit_HashAggrSetPartialAvgDouble( + char* vector, + int32_t row, + double sum, + int64_t count, + int8_t isNull) { + auto* rowVector = + reinterpret_cast(vector) + ->as(); + auto* sumVector = rowVector->childAt(0)->asFlatVector(); + auto* countVector = rowVector->childAt(1)->asFlatVector(); + if (isNull) { + rowVector->setNull(row, true); + return; + } + rowVector->setNull(row, false); + sumVector->set(row, sum); + countVector->set(row, count); +} + __attribute__((__visibility__("default"))) int8_t jit_ComplexTypeRowEqVectors( const char* row, int32_t offset, diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 0b91ccd02..f08032696 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -30,6 +30,12 @@ struct HashAggrJitBenchmarkCase { std::shared_ptr plan; }; +enum class AggregationPlanKind { + Single, + Partial, + PartialFinal, +}; + class HashAggrJitBenchmark : public VectorTestBase { public: void addBenchmark(const std::string& name, int32_t width) { @@ -53,10 +59,10 @@ class HashAggrJitBenchmark : public VectorTestBase { addCase(name + "_avg", rows, avgs); addCase(name + "_min", rows, mins); addCase(name + "_count", rows, counts); - addCase(name + "_merge_sum", rows, sums, true); - addCase(name + "_merge_avg", rows, avgs, true); - addCase(name + "_merge_min", rows, mins, true); - addCase(name + "_merge_count", rows, counts, true); + addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); } void addDecimalBenchmark(const std::string& name, int32_t width) { @@ -83,6 +89,20 @@ class 
HashAggrJitBenchmark : public VectorTestBase { addCase(name + "_double_max", rows, maxs); } + void addHighCardinalityExtractBenchmark(const std::string& name, int32_t width) { + auto rows = makeHighCardinalityRows(width); + std::vector avgs; + std::vector sums; + avgs.reserve(width); + sums.reserve(width); + for (auto i = 0; i < width; ++i) { + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + } + addCase(name + "_partial_avg_extract", rows, avgs, AggregationPlanKind::Partial); + addCase(name + "_partial_sum_extract", rows, sums, AggregationPlanKind::Partial); + } + private: std::vector makeRows(int32_t width) { std::vector names; @@ -155,16 +175,51 @@ class HashAggrJitBenchmark : public VectorTestBase { return std::vector(FLAGS_hashaggr_jit_benchmark_batches, batch); } + std::vector makeHighCardinalityRows(int32_t width) { + std::vector rows; + rows.reserve(FLAGS_hashaggr_jit_benchmark_batches); + for (auto batchIndex = 0; batchIndex < FLAGS_hashaggr_jit_benchmark_batches; ++batchIndex) { + std::vector names; + std::vector children; + names.reserve(width + 1); + children.reserve(width + 1); + names.push_back("c0"); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [batchIndex](vector_size_t row) { + return static_cast(batchIndex) * + FLAGS_hashaggr_jit_benchmark_batch_size + + row; + })); + for (auto column = 0; column < width; ++column) { + names.push_back(fmt::format("c{}", column + 1)); + children.push_back(makeFlatVector( + FLAGS_hashaggr_jit_benchmark_batch_size, + [batchIndex, column](vector_size_t row) { + return static_cast(batchIndex * 97 + column * 17 + row); + })); + } + rows.push_back(makeRowVector(names, children)); + } + return rows; + } + std::shared_ptr makePlan( const std::vector& rows, const std::vector& aggregates, - bool partialFinal = false) { + AggregationPlanKind planKind = AggregationPlanKind::Single) { exec::test::PlanBuilder builder; 
builder.values(rows); - if (partialFinal) { - builder.partialAggregation({"c0"}, aggregates).finalAggregation(); - } else { - builder.singleAggregation({"c0"}, aggregates); + switch (planKind) { + case AggregationPlanKind::Single: + builder.singleAggregation({"c0"}, aggregates); + break; + case AggregationPlanKind::Partial: + builder.partialAggregation({"c0"}, aggregates); + break; + case AggregationPlanKind::PartialFinal: + builder.partialAggregation({"c0"}, aggregates).finalAggregation(); + break; } return builder.planNode(); } @@ -182,9 +237,9 @@ class HashAggrJitBenchmark : public VectorTestBase { const std::string& name, const std::vector& rows, const std::vector& aggregates, - bool partialFinal = false) { + AggregationPlanKind planKind = AggregationPlanKind::Single) { auto testCase = std::make_unique(); - testCase->plan = makePlan(rows, aggregates, partialFinal); + testCase->plan = makePlan(rows, aggregates, planKind); // Warm up both paths so the benchmark compares steady-state execution and // doesn't charge one-time plan setup / JIT compilation to the first sample. 
run(testCase->plan, false); @@ -221,6 +276,7 @@ int main(int argc, char** argv) { benchmark.addBenchmark("width32", 32); benchmark.addDecimalBenchmark("width8", 8); benchmark.addFloatingPointMinMaxBenchmark("width8", 8); + benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); folly::runBenchmarks(); return 0; diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index d472bd696..46b3a34d7 100644 --- a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -195,6 +195,29 @@ TEST_F(SumAggregationTest, hashAggrJitMergeAndExtract) { assertEqualResults({noJit}, {jit}); } +TEST_F(SumAggregationTest, hashAggrJitPartialAvgExtractAccumulators) { + auto input = makeRowVector( + {makeFlatVector(2048, [](auto row) { return row; }), + makeFlatVector(2048, [](auto row) { return row * 3; }), + makeFlatVector(2048, [](auto row) { return row * 7; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .partialAggregation( + {"c0"}, {"spark_avg(c1)", "spark_avg(c2)"}) + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + TEST_F(SumAggregationTest, decimalSum) { std::vector> shortDecimalRawVector; std::vector> longDecimalRawVector; diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 91a1c1720..7bca0a24f 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -150,6 +150,11 @@ void ensureBuiltinDeclarations(llvm::Module& module) { 
declareFunction(module, "jit_HashAggrSetFlatI64", voidTy, {i8PtrTy, i32Ty, i64Ty, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatFloat", voidTy, {i8PtrTy, i32Ty, floatTy, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatDouble", voidTy, {i8PtrTy, i32Ty, doubleTy, i8Ty}); + declareFunction( + module, + "jit_HashAggrSetPartialAvgDouble", + voidTy, + {i8PtrTy, i32Ty, doubleTy, i64Ty, i8Ty}); } llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { @@ -619,7 +624,8 @@ std::string setFlatValueFunction(HashAggrJitValueKind kind) { bool genExtractIR( llvm::Module& module, const std::string& fn, - const std::vector& slots) { + const std::vector& slots, + bool partialOutput) { ensureBuiltinDeclarations(module); auto& context = module.getContext(); llvm::IRBuilder<> builder(context); @@ -667,6 +673,18 @@ bool genExtractIR( HashAggrJitValueKind resultKind = slot.accumulatorKind; llvm::Value* value = nullptr; llvm::Value* isNull = nullptr; + if (partialOutput && slot.kind == HashAggrJitKind::Avg) { + auto* sum = loadValue( + builder, group, llvmType(builder, slot.accumulatorKind), slot.offset); + auto* count = + loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); + auto* isNullValue = builder.CreateZExt( + isAccumulatorNull(builder, group, slot), builder.getInt8Ty()); + builder.CreateCall( + module.getFunction("jit_HashAggrSetPartialAvgDouble"), + {vector, row, sum, count, isNullValue}); + continue; + } if (slot.kind == HashAggrJitKind::Avg) { auto* sum = loadValue(builder, group, llvmType(builder, slot.accumulatorKind), slot.offset); auto* count = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); @@ -699,8 +717,10 @@ bool genExtractIR( } // namespace -HashAggrJitChunk::HashAggrJitChunk(std::vector slots) - : slots_(std::move(slots)) {} +HashAggrJitChunk::HashAggrJitChunk( + std::vector slots, + bool partialOutput) + : slots_(std::move(slots)), partialOutput_(partialOutput) {} std::string 
hashAggrJitValueKindName(HashAggrJitValueKind kind) { switch (kind) { @@ -738,7 +758,8 @@ bool isHashAggrJitSupportedType(TypeKind kind) { std::string HashAggrJitChunk::functionName() const { std::ostringstream out; - out << "jit_hashaggr_v2_n" << slots_.size(); + out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" + << slots_.size(); for (const auto& slot : slots_) { out << "_" << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) << hashAggrJitValueKindName(slot.accumulatorKind) << "o" << slot.offset @@ -749,13 +770,12 @@ std::string HashAggrJitChunk::functionName() const { return out.str(); } -bool HashAggrJitChunk::canExtract(bool partialOutput) const { +bool HashAggrJitChunk::canExtract() const { if (extract_ == nullptr || disabled_) { return false; } for (const auto& slot : slots_) { - if (slot.decimal || slot.accumulatorKind == HashAggrJitValueKind::Int128 || - (partialOutput && slot.kind == HashAggrJitKind::Avg)) { + if (slot.decimal || slot.accumulatorKind == HashAggrJitValueKind::Int128) { return false; } } @@ -792,7 +812,7 @@ bool HashAggrJitChunk::codegen() { return genInitIR(module, initFn, slots_) || genAddDenseIR(module, addFn, slots_, true) || genAddDenseIR(module, addNoNullFn, slots_, false) || - genExtractIR(module, extractFn, slots_); + genExtractIR(module, extractFn, slots_, partialOutput_); }, moduleKey); if (!module_) { diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 5af411a8a..2c5305c25 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -49,7 +49,9 @@ using HashAggrJitExtractFunc = void (*)(char** groups, int32_t numGroups, char** class HashAggrJitChunk { public: - explicit HashAggrJitChunk(std::vector slots); + explicit HashAggrJitChunk( + std::vector slots, + bool partialOutput = false); bool codegen(); @@ -57,7 +59,7 @@ class HashAggrJitChunk { return addDense_ != nullptr && !disabled_; } - bool canExtract(bool 
partialOutput) const; + bool canExtract() const; void disable() { disabled_ = true; @@ -94,6 +96,7 @@ class HashAggrJitChunk { private: std::vector slots_; + bool partialOutput_{false}; CompiledModuleSP module_; HashAggrJitInitFunc init_{nullptr}; HashAggrJitAddDenseFunc addDense_{nullptr}; From 92937d798881b80d025ebb41981cc3960f16cab8 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 1 Jun 2026 11:42:39 +0800 Subject: [PATCH 05/98] extract function-specific logics from aggregator --- bolt/exec/Aggregate.cpp | 28 +++ bolt/exec/Aggregate.h | 15 ++ bolt/exec/GroupingSet.cpp | 168 +++--------------- .../lib/aggregates/AverageAggregateBase.h | 48 +++++ .../lib/aggregates/SumAggregateBase.h | 43 +++++ .../prestosql/aggregates/CountAggregate.cpp | 36 ++++ .../prestosql/aggregates/MinMaxAggregates.cpp | 49 +++++ bolt/jit/aggregation/HashAggrJit.cpp | 33 ++++ bolt/jit/aggregation/HashAggrJit.h | 24 +++ 9 files changed, 302 insertions(+), 142 deletions(-) diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index c8f586b9a..2a34c1d58 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -321,4 +321,32 @@ void Aggregate::clearInternal() { numNulls_ = 0; } +#ifdef ENABLE_BOLT_JIT +bool Aggregate::supportsHashAggrJit( + const jit::HashAggrJitPlanContext& /*context*/) const { + return false; +} + +std::optional Aggregate::createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& /*context*/) const { + return std::nullopt; +} + +jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( + int32_t aggregateIndex, + const jit::HashAggrJitDescriptor& descriptor) const { + return jit::HashAggrJitSlot{ + aggregateIndex, + descriptor.kind, + descriptor.inputKind, + descriptor.accumulatorKind, + accumulatorOffset(), + accumulatorNullByte(), + accumulatorNullMask(), + descriptor.countStar, + descriptor.mergeInput, + descriptor.decimal}; +} +#endif + } // namespace bytedance::bolt::exec diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index 
a8422b787..301709f4c 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -39,6 +39,9 @@ #include "bolt/exec/AggregateUtil.h" #include "bolt/expression/FunctionSignature.h" #include "bolt/functions/InlineFlatten.h" +#ifdef ENABLE_BOLT_JIT +#include "bolt/jit/aggregation/HashAggrJit.h" +#endif #include "bolt/vector/BaseVector.h" namespace bytedance::bolt::core { class ExpressionEvaluator; @@ -112,6 +115,18 @@ class Aggregate { return false; } +#ifdef ENABLE_BOLT_JIT + virtual bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const; + + virtual std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const; + + jit::HashAggrJitSlot createHashAggrJitSlot( + int32_t aggregateIndex, + const jit::HashAggrJitDescriptor& descriptor) const; +#endif + void setAllocator(HashStringAllocator* allocator) { setAllocatorInternal(allocator); pool_ = allocator->pool(); diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 7721ed171..1753a2af5 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -62,166 +62,50 @@ bool areAllLazyNotLoaded(const std::vector& vectors) { }); } -#ifdef ENABLE_BOLT_JIT -std::string normalizedAggName(const std::string& name) { - static constexpr std::string_view kSparkPrefix{"spark_"}; - if (name.size() > kSparkPrefix.size() && - name.compare(0, kSparkPrefix.size(), kSparkPrefix) == 0) { - return name.substr(kSparkPrefix.size()); - } - return name; -} - -std::optional hashAggrJitKind(const std::string& name) { - const auto normalized = normalizedAggName(name); - if (normalized == "count") { - return jit::HashAggrJitKind::Count; - } - if (normalized == "sum") { - return jit::HashAggrJitKind::Sum; - } - if (normalized == "min") { - return jit::HashAggrJitKind::Min; - } - if (normalized == "max") { - return jit::HashAggrJitKind::Max; - } - if (normalized == "avg") { - return jit::HashAggrJitKind::Avg; - } - return std::nullopt; -} - -std::optional 
hashAggrJitValueKind(TypeKind kind) { - switch (kind) { - case TypeKind::TINYINT: - return jit::HashAggrJitValueKind::Int8; - case TypeKind::SMALLINT: - return jit::HashAggrJitValueKind::Int16; - case TypeKind::INTEGER: - return jit::HashAggrJitValueKind::Int32; - case TypeKind::BIGINT: - return jit::HashAggrJitValueKind::Int64; - case TypeKind::HUGEINT: - return jit::HashAggrJitValueKind::Int128; - case TypeKind::REAL: - return jit::HashAggrJitValueKind::Float; - case TypeKind::DOUBLE: - return jit::HashAggrJitValueKind::Double; - default: - return std::nullopt; - } -} - std::optional makeHashAggrJitSlot( int32_t aggregateIndex, const AggregateInfo& aggregate, - bool isRawInput) { + bool isRawInput, + bool isPartialOutput) { if (aggregate.distinct || aggregate.mask.has_value() || !aggregate.sortingKeys.empty()) { return std::nullopt; } - auto kind = hashAggrJitKind(aggregate.name); - if (!kind.has_value()) { + const int32_t inputCount = aggregate.inputs.size(); + if (!(isRawInput && inputCount == 0) && inputCount != 1) { return std::nullopt; } - const bool countStar = *kind == jit::HashAggrJitKind::Count && - aggregate.inputs.empty() && isRawInput; - if (!countStar && aggregate.inputs.size() != 1) { + const auto inputType = + inputCount == 0 ? nullptr + : (isRawInput ? aggregate.rawInputTypes[0] + : aggregate.intermediateType); + const jit::HashAggrJitPlanContext context{ + .isRawInput = isRawInput, + .isPartialOutput = isPartialOutput, + .inputCount = inputCount, + .inputType = inputType}; + if (!aggregate.function->supportsHashAggrJit(context)) { return std::nullopt; } - - jit::HashAggrJitValueKind inputKind = jit::HashAggrJitValueKind::Int64; - bool decimal = false; - if (!countStar) { - const auto& inputType = isRawInput ? 
aggregate.rawInputTypes[0] - : aggregate.intermediateType; - decimal = isRawInput && inputType->isDecimal() && - (*kind == jit::HashAggrJitKind::Sum || *kind == jit::HashAggrJitKind::Avg); - if (decimal) { - auto maybeInputKind = hashAggrJitValueKind(inputType->kind()); - if (!maybeInputKind.has_value() || - (*maybeInputKind != jit::HashAggrJitValueKind::Int64 && - *maybeInputKind != jit::HashAggrJitValueKind::Int128)) { - return std::nullopt; - } - inputKind = *maybeInputKind; - } else if (!isRawInput && *kind == jit::HashAggrJitKind::Avg) { - if (!inputType->isRow() || inputType->size() != 2 || - inputType->childAt(1)->kind() != TypeKind::BIGINT) { - return std::nullopt; - } - auto maybeInputKind = hashAggrJitValueKind(inputType->childAt(0)->kind()); - if (!maybeInputKind.has_value() || - *maybeInputKind != jit::HashAggrJitValueKind::Double) { - return std::nullopt; - } - inputKind = *maybeInputKind; - } else if (inputType->isDecimal() || inputType->isRow() || - !jit::isHashAggrJitSupportedType(inputType->kind())) { - return std::nullopt; - } else { - auto maybeInputKind = hashAggrJitValueKind(inputType->kind()); - if (!maybeInputKind.has_value()) { - return std::nullopt; - } - inputKind = *maybeInputKind; - } - } - - jit::HashAggrJitValueKind accumulatorKind = inputKind; - switch (*kind) { - case jit::HashAggrJitKind::Count: - accumulatorKind = jit::HashAggrJitValueKind::Int64; - break; - case jit::HashAggrJitKind::Sum: - if (decimal) { - accumulatorKind = jit::HashAggrJitValueKind::Int128; - break; - } - accumulatorKind = (inputKind == jit::HashAggrJitValueKind::Float || - inputKind == jit::HashAggrJitValueKind::Double) - ? 
jit::HashAggrJitValueKind::Double - : jit::HashAggrJitValueKind::Int64; - break; - case jit::HashAggrJitKind::Avg: - if (decimal) { - accumulatorKind = jit::HashAggrJitValueKind::Int128; - break; - } - accumulatorKind = jit::HashAggrJitValueKind::Double; - break; - case jit::HashAggrJitKind::Min: - case jit::HashAggrJitKind::Max: - accumulatorKind = inputKind; - break; + auto descriptor = aggregate.function->createHashAggrJitDescriptor(context); + if (!descriptor.has_value()) { + return std::nullopt; } - - return jit::HashAggrJitSlot{ - aggregateIndex, - *kind, - inputKind, - accumulatorKind, - aggregate.function->accumulatorOffset(), - aggregate.function->accumulatorNullByte(), - aggregate.function->accumulatorNullMask(), - countStar, - !isRawInput, - decimal}; + return aggregate.function->createHashAggrJitSlot(aggregateIndex, *descriptor); } std::string hashAggrJitSignature(const jit::HashAggrJitSlot& slot) { - return fmt::format( - "{}_{}_{}_{}_{}", - static_cast(slot.kind), - jit::hashAggrJitValueKindName(slot.inputKind), - jit::hashAggrJitValueKindName(slot.accumulatorKind), + return jit::HashAggrJitDescriptor{ + slot.kind, + slot.inputKind, + slot.accumulatorKind, + slot.countStar, slot.mergeInput, - slot.decimal); + slot.decimal} + .signature(); } -#endif } // namespace @@ -959,7 +843,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto minChunkWidth = std::max(minFuseWidth, compileMinCount); std::unordered_map> groups; for (auto i = 0; i < aggregates_.size(); ++i) { - auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_); + auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); if (!slot.has_value()) { continue; } diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index f8869255f..ec0237120 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -102,6 +102,54 @@ class 
AverageAggregateBase : public exec::Aggregate { return true; } +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.inputCount != 1 || !context.inputType) { + return false; + } + if (context.isRawInput) { + return context.inputType->isDecimal() || + jit::isHashAggrJitSupportedType(context.inputType->kind()) || + context.inputType->kind() == TypeKind::HUGEINT; + } + return context.inputType->isRow() && context.inputType->size() == 2 && + context.inputType->childAt(1)->kind() == TypeKind::BIGINT && + context.inputType->childAt(0)->kind() == TypeKind::DOUBLE; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + + if (!context.isRawInput) { + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Avg, + jit::HashAggrJitValueKind::Double, + jit::HashAggrJitValueKind::Double, + false, + true, + false}; + } + + const bool decimal = context.inputType->isDecimal(); + auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Avg, + *inputKind, + decimal ? 
jit::HashAggrJitValueKind::Int128 + : jit::HashAggrJitValueKind::Double, + false, + false, + decimal}; + } +#endif + FLATTEN void toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index acafb3766..ae9e35aef 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -63,6 +63,49 @@ class SumAggregateBase return true; } +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.inputCount != 1 || !context.inputType) { + return false; + } + if (context.inputType->isRow()) { + return false; + } + return context.inputType->isDecimal() || + jit::isHashAggrJitSupportedType(context.inputType->kind()) || + context.inputType->kind() == TypeKind::HUGEINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + + const bool decimal = context.isRawInput && context.inputType->isDecimal(); + auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + + auto accumulatorKind = decimal + ? jit::HashAggrJitValueKind::Int128 + : ((*inputKind == jit::HashAggrJitValueKind::Float || + *inputKind == jit::HashAggrJitValueKind::Double) + ? 
jit::HashAggrJitValueKind::Double + : jit::HashAggrJitValueKind::Int64); + + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Sum, + *inputKind, + accumulatorKind, + false, + !context.isRawInput, + decimal}; + } +#endif + void toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index da49ba53e..497d52d90 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -53,6 +53,42 @@ class CountAggregate : public SimpleNumericAggregate { return true; } +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.isRawInput) { + return context.inputCount == 0 || + (context.inputCount == 1 && context.inputType != nullptr && + !context.inputType->isRow() && !context.inputType->isDecimal() && + jit::isHashAggrJitSupportedType(context.inputType->kind())); + } + return context.inputCount == 1 && context.inputType != nullptr && + context.inputType->kind() == TypeKind::BIGINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + auto inputKind = jit::HashAggrJitValueKind::Int64; + if (!context.isCountStar()) { + auto maybeInputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!maybeInputKind.has_value()) { + return std::nullopt; + } + inputKind = *maybeInputKind; + } + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Count, + inputKind, + jit::HashAggrJitValueKind::Int64, + context.isCountStar(), + !context.isRawInput, + false}; + } +#endif + void toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp 
index 3fa9ed216..95fe410cf 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -52,6 +52,37 @@ class MinMaxAggregate : public SimpleNumericAggregate { public: explicit MinMaxAggregate(TypePtr resultType) : BaseAggregate(resultType) {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + return context.inputCount == 1 && context.inputType != nullptr && + !context.inputType->isRow() && !context.inputType->isDecimal() && + (jit::isHashAggrJitSupportedType(context.inputType->kind()) || + context.inputType->kind() == TypeKind::HUGEINT); + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + return jit::HashAggrJitDescriptor{ + jitKind(), + *inputKind, + *inputKind, + false, + !context.isRawInput, + false}; + } + + protected: + virtual jit::HashAggrJitKind jitKind() const = 0; +#endif + int32_t accumulatorFixedWidthSize() const override { return sizeof(T); } @@ -135,6 +166,15 @@ class MaxAggregate : public MinMaxAggregate { public: explicit MaxAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} +#ifdef ENABLE_BOLT_JIT + protected: + jit::HashAggrJitKind jitKind() const override { + return jit::HashAggrJitKind::Max; + } + + public: +#endif + void initializeNewGroups( char** groups, folly::Range indices) final { @@ -225,6 +265,15 @@ class MinAggregate : public MinMaxAggregate { public: explicit MinAggregate(TypePtr resultType) : MinMaxAggregate(resultType) {} +#ifdef ENABLE_BOLT_JIT + protected: + jit::HashAggrJitKind jitKind() const override { + return jit::HashAggrJitKind::Min; + } + + public: +#endif + void initializeNewGroups( char** groups, folly::Range 
indices) final { diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 7bca0a24f..f3c419f84 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -9,6 +9,8 @@ #include #include +#include + #include "bolt/jit/ThrustJITv2.h" extern "C" { @@ -742,6 +744,27 @@ std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { return "unknown"; } +std::optional hashAggrJitValueKind(TypeKind kind) { + switch (kind) { + case TypeKind::TINYINT: + return HashAggrJitValueKind::Int8; + case TypeKind::SMALLINT: + return HashAggrJitValueKind::Int16; + case TypeKind::INTEGER: + return HashAggrJitValueKind::Int32; + case TypeKind::BIGINT: + return HashAggrJitValueKind::Int64; + case TypeKind::HUGEINT: + return HashAggrJitValueKind::Int128; + case TypeKind::REAL: + return HashAggrJitValueKind::Float; + case TypeKind::DOUBLE: + return HashAggrJitValueKind::Double; + default: + return std::nullopt; + } +} + bool isHashAggrJitSupportedType(TypeKind kind) { switch (kind) { case TypeKind::TINYINT: @@ -756,6 +779,16 @@ bool isHashAggrJitSupportedType(TypeKind kind) { } } +std::string HashAggrJitDescriptor::signature() const { + return fmt::format( + "{}_{}_{}_{}_{}", + static_cast(kind), + hashAggrJitValueKindName(inputKind), + hashAggrJitValueKindName(accumulatorKind), + mergeInput, + decimal); +} + std::string HashAggrJitChunk::functionName() const { std::ostringstream out; out << "jit_hashaggr_v2_" << (partialOutput_ ? 
"partial" : "final") << "_n" diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 2c5305c25..ddc33cd74 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -12,6 +13,17 @@ namespace bytedance::bolt::jit { +struct HashAggrJitPlanContext { + bool isRawInput{false}; + bool isPartialOutput{false}; + int32_t inputCount{0}; + TypePtr inputType; + + bool isCountStar() const { + return isRawInput && inputCount == 0; + } +}; + enum class HashAggrJitKind : uint8_t { Count, Sum, @@ -30,6 +42,17 @@ enum class HashAggrJitValueKind : uint8_t { Double, }; +struct HashAggrJitDescriptor { + HashAggrJitKind kind; + HashAggrJitValueKind inputKind; + HashAggrJitValueKind accumulatorKind; + bool countStar{false}; + bool mergeInput{false}; + bool decimal{false}; + + std::string signature() const; +}; + struct HashAggrJitSlot { int32_t aggregateIndex; HashAggrJitKind kind; @@ -106,6 +129,7 @@ class HashAggrJitChunk { }; bool isHashAggrJitSupportedType(TypeKind kind); +std::optional hashAggrJitValueKind(TypeKind kind); std::string hashAggrJitValueKindName(HashAggrJitValueKind kind); } // namespace bytedance::bolt::jit From ad04ebc3548879a6e5f8a77931f8a4b586a42636 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 1 Jun 2026 14:08:27 +0800 Subject: [PATCH 06/98] refactor slot group logics --- bolt/exec/GroupingSet.cpp | 51 ++++++++----------- .../aggregates/tests/SumAggregationTest.cpp | 43 ++++++++++++++++ 2 files changed, 65 insertions(+), 29 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 1753a2af5..31d2c0f36 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -96,17 +96,6 @@ std::optional makeHashAggrJitSlot( return aggregate.function->createHashAggrJitSlot(aggregateIndex, *descriptor); } -std::string hashAggrJitSignature(const jit::HashAggrJitSlot& slot) { - return 
jit::HashAggrJitDescriptor{ - slot.kind, - slot.inputKind, - slot.accumulatorKind, - slot.countStar, - slot.mergeInput, - slot.decimal} - .signature(); -} - } // namespace GroupingSet::GroupingSet( @@ -841,32 +830,36 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto compileMinCount = std::max(1, queryConfig_.hashAggrJitCompileMinCount()); const auto minChunkWidth = std::max(minFuseWidth, compileMinCount); - std::unordered_map> groups; + std::vector currentChunkSlots; + currentChunkSlots.reserve(maxFuseWidth); + + auto flushChunk = [&]() { + if (currentChunkSlots.size() < minChunkWidth) { + currentChunkSlots.clear(); + return; + } + jit::HashAggrJitChunk chunk(std::move(currentChunkSlots), isPartial_); + if (chunk.codegen()) { + hashAggrJitChunks_.push_back(std::move(chunk)); + } + currentChunkSlots.clear(); + currentChunkSlots.reserve(maxFuseWidth); + }; + for (auto i = 0; i < aggregates_.size(); ++i) { auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); if (!slot.has_value()) { + flushChunk(); continue; } - groups[hashAggrJitSignature(*slot)].push_back(*slot); - } - for (auto& [_, slots] : groups) { - if (slots.size() < minChunkWidth) { - continue; - } - for (auto begin = 0; begin < slots.size(); begin += maxFuseWidth) { - const auto end = std::min(begin + maxFuseWidth, slots.size()); - std::vector chunkSlots( - slots.begin() + begin, slots.begin() + end); - if (chunkSlots.size() < minChunkWidth) { - continue; - } - jit::HashAggrJitChunk chunk(std::move(chunkSlots), isPartial_); - if (chunk.codegen()) { - hashAggrJitChunks_.push_back(std::move(chunk)); - } + if (currentChunkSlots.size() >= maxFuseWidth) { + flushChunk(); } + currentChunkSlots.push_back(*slot); } + + flushChunk(); } void GroupingSet::runHashAggrJitChunks( diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 46b3a34d7..a29da74f5 100644 --- 
a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -218,6 +218,49 @@ TEST_F(SumAggregationTest, hashAggrJitPartialAvgExtractAccumulators) { assertEqualResults({noJit}, {jit}); } +TEST_F(SumAggregationTest, hashAggrJitSplitsContiguousSegments) { + auto input = makeRowVector( + {makeFlatVector(512, [](auto row) { return row % 16; }), + makeFlatVector(512, [](auto row) { return row; }), + makeFlatVector(512, [](auto row) { return row * 2; }), + makeFlatVector(512, [](auto row) { return 1000 - row; }), + makeFlatVector(512, [](auto row) { return row * 5; }), + makeFlatVector(512, [](auto row) { return row * 7; }), + makeFlatVector(512, [](auto row) { return row * 11; }), + makeFlatVector(512, [](auto row) { return row * 13; }), + makeFlatVector(512, [](auto row) { return row % 9; }), + makeFlatVector(512, [](auto row) { return row % 17; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .singleAggregation( + {"c0"}, + {"min(c1)", + "max(c2)", + "spark_sum(c3)", + "spark_avg(c4)", + "min(c5)", + "max(c6)", + "spark_sum(c7)", + "spark_avg(c8)", + "spark_collect_list(c1)", + "spark_collect_list(c2)", + "min(c9)", + "max(c9)"}) + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") + .config(core::QueryConfig::kHashAggrJitMaxFuseWidth, "4") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + TEST_F(SumAggregationTest, decimalSum) { std::vector> shortDecimalRawVector; std::vector> longDecimalRawVector; From 921e377850d5c0b0950ba23b756f02f76524d6ff Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 1 Jun 2026 16:06:46 +0800 Subject: [PATCH 07/98] push down 
aggr func ir codegen from hashaggrjit to separate aggr functions --- bolt/exec/Aggregate.cpp | 3 +- bolt/exec/GroupingSet.cpp | 2 +- .../lib/aggregates/AverageAggregateBase.h | 144 +++++++- .../lib/aggregates/SumAggregateBase.h | 98 ++++- .../prestosql/aggregates/CountAggregate.cpp | 72 +++- .../prestosql/aggregates/MinMaxAggregates.cpp | 91 ++++- bolt/jit/aggregation/HashAggrJit.cpp | 340 +++++++----------- bolt/jit/aggregation/HashAggrJit.h | 105 ++++++ 8 files changed, 633 insertions(+), 222 deletions(-) diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index 2a34c1d58..8de8ccc0f 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -345,7 +345,8 @@ jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( accumulatorNullMask(), descriptor.countStar, descriptor.mergeInput, - descriptor.decimal}; + descriptor.decimal, + descriptor.ops}; } #endif diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 31d2c0f36..99c6b54e9 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -90,7 +90,7 @@ std::optional makeHashAggrJitSlot( return std::nullopt; } auto descriptor = aggregate.function->createHashAggrJitDescriptor(context); - if (!descriptor.has_value()) { + if (!descriptor.has_value() || descriptor->ops == nullptr) { return std::nullopt; } return aggregate.function->createHashAggrJitSlot(aggregateIndex, *descriptor); diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index ec0237120..8b741130e 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -131,7 +131,8 @@ class AverageAggregateBase : public exec::Aggregate { jit::HashAggrJitValueKind::Double, false, true, - false}; + false, + hashAggrJitOps()}; } const bool decimal = context.inputType->isDecimal(); @@ -146,8 +147,147 @@ class AverageAggregateBase : public exec::Aggregate { :
jit::HashAggrJitValueKind::Double, false, false, - decimal}; + decimal, + hashAggrJitOps()}; + } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + if (slot.decimal) { + codegen.builder().CreateCall( + codegen.module().getFunction("jit_HashAggrInitDecimalAvg"), + {group, codegen.builder().getInt32(slot.offset)}); + return; + } + codegen.storeValue( + group, + codegen.llvmType(slot.accumulatorKind), + slot.offset, + llvm::ConstantFP::get(codegen.llvmType(slot.accumulatorKind), 0.0)); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().getInt64(0)); + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + if (slot.mergeInput) { + codegen.clearAccumulatorNull(group, slot); + auto* sum = codegen.loadAvgMergeField( + decoded, row, 0, codegen.builder().getDoubleTy()); + auto* count = codegen.loadAvgMergeField( + decoded, row, 1, codegen.builder().getInt64Ty()); + auto* oldSum = codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getDoubleTy(), + slot.offset, + codegen.builder().CreateFAdd(oldSum, sum)); + auto* oldCount = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, count)); + return; + } + + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + if (slot.decimal) { + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? 
"jit_HashAggrUpdateDecimalAvgI128" + : "jit_HashAggrUpdateDecimalAvgI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? codegen.castValue( + rawValue, + slot.inputKind, + jit::HashAggrJitValueKind::Int128) + : rawValue}); + return; + } + + auto* value = codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldSum = codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + codegen.storeValue( + group, + codegen.llvmType(slot.accumulatorKind), + slot.offset, + codegen.builder().CreateFAdd(oldSum, value)); + auto* oldCount = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot& slot, + bool partialOutput) { + if (slot.decimal || slot.accumulatorKind == jit::HashAggrJitValueKind::Int128) { + return false; + } + return !partialOutput || slot.accumulatorKind == jit::HashAggrJitValueKind::Double; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + auto* sum = + codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + auto* count = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset + 8); + auto* isNull = codegen.builder().CreateZExt( + codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); + if (target.partialOutput) { + codegen.emitPartialAvgResult(target.resultVector, target.row, sum, count, isNull); + return; + } + auto* countIsZero = codegen.builder().CreateICmpEQ(count, 
codegen.builder().getInt64(0)); + auto* divisor = codegen.builder().CreateSIToFP(count, codegen.llvmType(slot.accumulatorKind)); + auto* value = codegen.builder().CreateFDiv(sum, divisor); + auto* finalIsNull = codegen.builder().CreateZExt(countIsZero, codegen.builder().getInt8Ty()); + codegen.emitFlatValue( + target.resultVector, + target.row, + slot.accumulatorKind, + value, + finalIsNull); + } + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "avg", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } + + public: #endif FLATTEN void toIntermediate( diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index ae9e35aef..7cb2fb3cd 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -102,8 +102,104 @@ class SumAggregateBase accumulatorKind, false, !context.isRawInput, - decimal}; + decimal, + hashAggrJitOps()}; } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + if (slot.decimal) { + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateCall( + codegen.module().getFunction( + slot.kind == jit::HashAggrJitKind::Sum + ? 
"jit_HashAggrInitDecimalSum" + : "jit_HashAggrInitDecimalAvg"), + {group, codegen.builder().getInt32(slot.offset)}); + return; + } + codegen.setAccumulatorNull(group, slot); + auto* accType = codegen.llvmType(slot.accumulatorKind); + if (codegen.isFloatKind(slot.accumulatorKind)) { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); + } else { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantInt::get(accType, 0)); + } + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + if (slot.decimal) { + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalSumI128" + : "jit_HashAggrUpdateDecimalSumI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? codegen.castValue( + rawValue, + slot.inputKind, + jit::HashAggrJitValueKind::Int128) + : rawValue}); + return; + } + auto* value = + codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); + auto* accType = codegen.llvmType(slot.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldValue = codegen.loadValue(group, accType, slot.offset); + auto* newValue = codegen.isFloatKind(slot.accumulatorKind) + ? 
codegen.builder().CreateFAdd(oldValue, value) + : codegen.builder().CreateAdd(oldValue, value); + codegen.storeValue(group, accType, slot.offset, newValue); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot& slot, + bool) { + return !slot.decimal && + slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue( + group, codegen.llvmType(slot.accumulatorKind), slot.offset); + auto* isNull = codegen.builder().CreateZExt( + codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); + codegen.emitFlatValue( + target.resultVector, target.row, slot.accumulatorKind, value, isNull); + } + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "sum", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } + + public: #endif void toIntermediate( diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 497d52d90..bc7a0b2a4 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -85,8 +85,78 @@ class CountAggregate : public SimpleNumericAggregate { jit::HashAggrJitValueKind::Int64, context.isCountStar(), !context.isRawInput, - false}; + false, + hashAggrJitOps()}; } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset, + codegen.builder().getInt64(0)); + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, 
+ llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* state = codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); + llvm::Value* inc = nullptr; + if (slot.countStar || !slot.mergeInput) { + inc = codegen.builder().getInt64(1); + } else { + inc = codegen.castValue( + codegen.loadDecodedValue(decoded, row, slot), + slot.inputKind, + jit::HashAggrJitValueKind::Int64); + } + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset, + codegen.builder().CreateAdd(state, inc)); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot&, + bool) { + return true; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); + codegen.emitFlatValue( + target.resultVector, + target.row, + jit::HashAggrJitValueKind::Int64, + value, + codegen.builder().getInt8(0)); + } + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "count", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } + + public: #endif void toIntermediate( diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 95fe410cf..be194eebb 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -76,11 +76,100 @@ class MinMaxAggregate : public SimpleNumericAggregate { *inputKind, false, !context.isRawInput, - false}; + false, + hashAggrJitOps()}; + } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + 
codegen.setAccumulatorNull(group, slot); + auto* type = codegen.llvmType(slot.accumulatorKind); + if (codegen.isFloatKind(slot.accumulatorKind)) { + codegen.storeValue(group, type, slot.offset, llvm::ConstantFP::get(type, 0.0)); + } else { + codegen.storeValue(group, type, slot.offset, llvm::ConstantInt::get(type, 0)); + } + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* value = codegen.castValue( + codegen.loadDecodedValue(decoded, row, slot), + slot.inputKind, + slot.accumulatorKind); + auto* type = codegen.llvmType(slot.accumulatorKind); + auto* oldValue = codegen.loadValue(group, type, slot.offset); + auto* nullState = codegen.isAccumulatorNull(group, slot); + llvm::Value* better = nullptr; + if (codegen.isFloatKind(slot.accumulatorKind)) { + auto* oldIsNan = codegen.builder().CreateFCmpUNO(oldValue, oldValue); + auto* valueIsNan = codegen.builder().CreateFCmpUNO(value, value); + if (slot.kind == jit::HashAggrJitKind::Min) { + better = codegen.builder().CreateOr( + codegen.builder().CreateAnd(oldIsNan, codegen.builder().CreateNot(valueIsNan)), + codegen.builder().CreateAnd( + codegen.builder().CreateNot(valueIsNan), + codegen.builder().CreateFCmpOGT(oldValue, value))); + } else { + better = codegen.builder().CreateAnd( + codegen.builder().CreateNot(oldIsNan), + codegen.builder().CreateOr( + valueIsNan, codegen.builder().CreateFCmpOLT(oldValue, value))); + } + } else { + better = slot.kind == jit::HashAggrJitKind::Min + ? 
codegen.builder().CreateICmpSLT(value, oldValue) + : codegen.builder().CreateICmpSGT(value, oldValue); + } + auto* shouldStore = codegen.builder().CreateOr(nullState, better); + codegen.storeValue( + group, + type, + slot.offset, + codegen.builder().CreateSelect(shouldStore, value, oldValue)); + codegen.clearAccumulatorNull(group, slot); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot& slot, + bool) { + return slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + auto* isNull = codegen.builder().CreateZExt( + codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); + codegen.emitFlatValue( + target.resultVector, target.row, slot.accumulatorKind, value, isNull); + } + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "minmax", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; } protected: virtual jit::HashAggrJitKind jitKind() const = 0; + public: #endif int32_t accumulatorFixedWidthSize() const override { diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index f3c419f84..ad1ec80e3 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -199,6 +199,8 @@ std::string decodedValueFunction(HashAggrJitValueKind kind) { return "jit_GetDecodedValueI64"; } +std::string setFlatValueFunction(HashAggrJitValueKind kind); + bool isFloatKind(HashAggrJitValueKind kind) { return kind == HashAggrJitValueKind::Float || kind == HashAggrJitValueKind::Double; @@ -300,145 +302,113 @@ llvm::Value* loadDecodedValue( return builder.CreateCall(callee, 
{decoded, row}); } -void genCountUpdate( - llvm::IRBuilder<>& builder, +} // namespace + +HashAggrJitCodegen::HashAggrJitCodegen(llvm::Module& module) : module_(module) { + ensureBuiltinDeclarations(module_); +} + +llvm::Type* HashAggrJitCodegen::llvmType(HashAggrJitValueKind kind) const { + return ::bytedance::bolt::jit::llvmType(builder(), kind); +} + +llvm::Value* HashAggrJitCodegen::loadDecodedValue( + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot) const { + return ::bytedance::bolt::jit::loadDecodedValue( + builder(), module_, decoded, row, slot); +} + +llvm::Value* HashAggrJitCodegen::isAccumulatorNull( llvm::Value* group, - const HashAggrJitSlot& slot) { - auto* state = loadValue(builder, group, builder.getInt64Ty(), slot.offset); - storeValue( - builder, - group, - builder.getInt64Ty(), - slot.offset, - builder.CreateAdd(state, builder.getInt64(1))); + const HashAggrJitSlot& slot) const { + return ::bytedance::bolt::jit::isAccumulatorNull(builder(), group, slot); } -void genNonNullUpdate( - llvm::IRBuilder<>& builder, - llvm::Module& module, +void HashAggrJitCodegen::clearAccumulatorNull( llvm::Value* group, - llvm::Value* rawValue, - const HashAggrJitSlot& slot) { - auto* accType = llvmType(builder, slot.accumulatorKind); - auto* value = castValue(builder, rawValue, slot.inputKind, slot.accumulatorKind); - if (slot.decimal) { - clearAccumulatorNull(builder, group, slot); - const auto helper = slot.kind == HashAggrJitKind::Sum - ? (slot.inputKind == HashAggrJitValueKind::Int128 - ? "jit_HashAggrUpdateDecimalSumI128" - : "jit_HashAggrUpdateDecimalSumI64") - : (slot.inputKind == HashAggrJitValueKind::Int128 - ? "jit_HashAggrUpdateDecimalAvgI128" - : "jit_HashAggrUpdateDecimalAvgI64"); - builder.CreateCall( - module.getFunction(helper), - {group, - builder.getInt32(slot.offset), - slot.inputKind == HashAggrJitValueKind::Int128 ? 
value : rawValue}); - return; - } - switch (slot.kind) { - case HashAggrJitKind::Sum: { - clearAccumulatorNull(builder, group, slot); - auto* oldValue = loadValue(builder, group, accType, slot.offset); - auto* newValue = isFloatKind(slot.accumulatorKind) - ? builder.CreateFAdd(oldValue, value) - : builder.CreateAdd(oldValue, value); - storeValue(builder, group, accType, slot.offset, newValue); - break; - } - case HashAggrJitKind::Avg: { - clearAccumulatorNull(builder, group, slot); - auto* oldSum = loadValue(builder, group, accType, slot.offset); - auto* newSum = builder.CreateFAdd(oldSum, value); - storeValue(builder, group, accType, slot.offset, newSum); - auto* oldCount = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); - storeValue( - builder, - group, - builder.getInt64Ty(), - slot.offset + 8, - builder.CreateAdd(oldCount, builder.getInt64(1))); - break; - } - case HashAggrJitKind::Min: - case HashAggrJitKind::Max: { - auto* oldValue = loadValue(builder, group, accType, slot.offset); - auto* nullState = isAccumulatorNull(builder, group, slot); - llvm::Value* better; - if (isFloatKind(slot.accumulatorKind)) { - auto* oldIsNan = builder.CreateFCmpUNO(oldValue, oldValue); - auto* valueIsNan = builder.CreateFCmpUNO(value, value); - if (slot.kind == HashAggrJitKind::Min) { - better = builder.CreateOr( - builder.CreateAnd(oldIsNan, builder.CreateNot(valueIsNan)), - builder.CreateAnd( - builder.CreateNot(valueIsNan), - builder.CreateFCmpOGT(oldValue, value))); - } else { - better = builder.CreateAnd( - builder.CreateNot(oldIsNan), - builder.CreateOr(valueIsNan, builder.CreateFCmpOLT(oldValue, value))); - } - } else { - better = slot.kind == HashAggrJitKind::Min - ? 
builder.CreateICmpSLT(value, oldValue) - : builder.CreateICmpSGT(value, oldValue); - } - auto* shouldStore = builder.CreateOr(nullState, better); - auto* selected = builder.CreateSelect(shouldStore, value, oldValue); - storeValue(builder, group, accType, slot.offset, selected); - clearAccumulatorNull(builder, group, slot); - break; - } - case HashAggrJitKind::Count: - if (slot.mergeInput) { - auto* state = loadValue(builder, group, builder.getInt64Ty(), slot.offset); - storeValue( - builder, - group, - builder.getInt64Ty(), - slot.offset, - builder.CreateAdd(state, castValue(builder, rawValue, slot.inputKind, HashAggrJitValueKind::Int64))); - } else { - genCountUpdate(builder, group, slot); - } - break; - } + const HashAggrJitSlot& slot) const { + ::bytedance::bolt::jit::clearAccumulatorNull(builder(), group, slot); } -void genAvgMergeUpdate( - llvm::IRBuilder<>& builder, - llvm::Module& module, +void HashAggrJitCodegen::setAccumulatorNull( llvm::Value* group, + const HashAggrJitSlot& slot) const { + ::bytedance::bolt::jit::setAccumulatorNull(builder(), group, slot); +} + +llvm::LoadInst* HashAggrJitCodegen::loadValue( + llvm::Value* row, + llvm::Type* type, + int32_t offset) const { + return ::bytedance::bolt::jit::loadValue(builder(), row, type, offset); +} + +void HashAggrJitCodegen::storeValue( + llvm::Value* row, + llvm::Type* type, + int32_t offset, + llvm::Value* value) const { + ::bytedance::bolt::jit::storeValue(builder(), row, type, offset, value); +} + +llvm::Value* HashAggrJitCodegen::castValue( + llvm::Value* value, + HashAggrJitValueKind from, + HashAggrJitValueKind to) const { + return ::bytedance::bolt::jit::castValue(builder(), value, from, to); +} + +bool HashAggrJitCodegen::isFloatKind(HashAggrJitValueKind kind) const { + return ::bytedance::bolt::jit::isFloatKind(kind); +} + +llvm::Value* HashAggrJitCodegen::loadAvgMergeField( llvm::Value* decoded, llvm::Value* row, - const HashAggrJitSlot& slot) { - clearAccumulatorNull(builder, group, slot); - 
auto* sum = builder.CreateCall( - module.getFunction("jit_GetDecodedRowFieldDouble"), - {decoded, row, builder.getInt32(0)}); - auto* count = builder.CreateCall( - module.getFunction("jit_GetDecodedRowFieldI64"), - {decoded, row, builder.getInt32(1)}); - - auto* oldSum = loadValue(builder, group, builder.getDoubleTy(), slot.offset); - storeValue( - builder, - group, - builder.getDoubleTy(), - slot.offset, - builder.CreateFAdd(oldSum, sum)); + int32_t field, + llvm::Type* type) const { + const char* name = type->isDoubleTy() ? "jit_GetDecodedRowFieldDouble" + : "jit_GetDecodedRowFieldI64"; + return builder().CreateCall( + module_.getFunction(name), {decoded, row, builder().getInt32(field)}); +} - auto* oldCount = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); - storeValue( - builder, - group, - builder.getInt64Ty(), - slot.offset + 8, - builder.CreateAdd(oldCount, count)); +void HashAggrJitCodegen::emitFlatValue( + llvm::Value* vector, + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* value, + llvm::Value* isNull) const { + const auto setter = setFlatValueFunction(kind); + if (setter.empty()) { + return; + } + builder().CreateCall( + module_.getFunction(setter), {vector, row, value, isNull}); +} + +void HashAggrJitCodegen::resizeResultVector( + llvm::Value* vector, + llvm::Value* size) const { + builder().CreateCall( + module_.getFunction("jit_HashAggrResizeVector"), {vector, size}); +} + +void HashAggrJitCodegen::emitPartialAvgResult( + llvm::Value* vector, + llvm::Value* row, + llvm::Value* sum, + llvm::Value* count, + llvm::Value* isNull) const { + builder().CreateCall( + module_.getFunction("jit_HashAggrSetPartialAvgDouble"), + {vector, row, sum, count, isNull}); } +namespace { + bool genAddDenseIR( llvm::Module& module, const std::string& fn, @@ -449,9 +419,10 @@ bool genInitIR( llvm::Module& module, const std::string& fn, const std::vector& slots) { - ensureBuiltinDeclarations(module); auto& context = module.getContext(); 
llvm::IRBuilder<> builder(context); + HashAggrJitCodegen codegen(module); + codegen.setBuilder(&builder); auto* voidTy = builder.getVoidTy(); auto* i8PtrTy = llvm::PointerType::get(context, 0); auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); @@ -479,41 +450,10 @@ bool genInitIR( auto* group = builder.CreateLoad(i8PtrTy, groupAddr); for (const auto& slot : slots) { - if (slot.kind != HashAggrJitKind::Count) { - setAccumulatorNull(builder, group, slot); - } - if (slot.decimal) { - builder.CreateCall( - module.getFunction( - slot.kind == HashAggrJitKind::Sum ? "jit_HashAggrInitDecimalSum" - : "jit_HashAggrInitDecimalAvg"), - {group, builder.getInt32(slot.offset)}); - continue; - } - auto* accType = llvmType(builder, slot.accumulatorKind); - if (isFloatKind(slot.accumulatorKind)) { - storeValue( - builder, - group, - accType, - slot.offset, - llvm::ConstantFP::get(accType, 0.0)); - } else { - storeValue( - builder, - group, - accType, - slot.offset, - llvm::ConstantInt::get(accType, 0)); - } - if (slot.kind == HashAggrJitKind::Avg) { - storeValue( - builder, - group, - builder.getInt64Ty(), - slot.offset + 8, - builder.getInt64(0)); + if (slot.ops == nullptr || slot.ops->create == nullptr) { + return true; } + slot.ops->create(codegen, group, slot); } auto* next = builder.CreateAdd(index, builder.getInt32(1)); @@ -531,9 +471,10 @@ bool genAddDenseIR( const std::string& fn, const std::vector& slots, bool checkInputNulls) { - ensureBuiltinDeclarations(module); auto& context = module.getContext(); llvm::IRBuilder<> builder(context); + HashAggrJitCodegen codegen(module); + codegen.setBuilder(&builder); auto* voidTy = builder.getVoidTy(); auto* i8PtrTy = llvm::PointerType::get(context, 0); auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); @@ -562,16 +503,11 @@ bool genAddDenseIR( for (auto i = 0; i < slots.size(); ++i) { const auto& slot = slots[i]; - if (slot.kind == HashAggrJitKind::Count && slot.countStar) { - genCountUpdate(builder, group, slot); - continue; - } - auto* 
updateBlock = llvm::BasicBlock::Create(context, "slot_update", func, end); auto* nextBlock = llvm::BasicBlock::Create(context, "slot_next", func, end); auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); - if (checkInputNulls) { + if (checkInputNulls && !slot.countStar) { auto* isNull = builder.CreateICmpNE( builder.CreateCall(module.getFunction("jit_GetDecodedIsNull"), {decoded, row}), builder.getInt8(0)); @@ -581,14 +517,10 @@ bool genAddDenseIR( } builder.SetInsertPoint(updateBlock); - if (slot.kind == HashAggrJitKind::Count && !slot.mergeInput) { - genCountUpdate(builder, group, slot); - } else if (slot.kind == HashAggrJitKind::Avg && slot.mergeInput) { - genAvgMergeUpdate(builder, module, group, decoded, row, slot); - } else { - auto* value = loadDecodedValue(builder, module, decoded, row, slot); - genNonNullUpdate(builder, module, group, value, slot); + if (slot.ops == nullptr || slot.ops->add == nullptr) { + return true; } + slot.ops->add(codegen, group, decoded, row, slot, checkInputNulls, nextBlock); builder.CreateBr(nextBlock); builder.SetInsertPoint(nextBlock); } @@ -628,9 +560,10 @@ bool genExtractIR( const std::string& fn, const std::vector& slots, bool partialOutput) { - ensureBuiltinDeclarations(module); auto& context = module.getContext(); llvm::IRBuilder<> builder(context); + HashAggrJitCodegen codegen(module); + codegen.setBuilder(&builder); auto* voidTy = builder.getVoidTy(); auto* i8PtrTy = llvm::PointerType::get(context, 0); auto* i8PtrPtrTy = i8PtrTy->getPointerTo(); @@ -650,12 +583,13 @@ bool genExtractIR( auto* end = llvm::BasicBlock::Create(context, "end", func); builder.SetInsertPoint(entry); for (auto i = 0; i < slots.size(); ++i) { - if (slots[i].decimal || slots[i].accumulatorKind == HashAggrJitValueKind::Int128) { + if (slots[i].ops == nullptr || slots[i].ops->canExtract == nullptr || + !slots[i].ops->canExtract(slots[i], partialOutput)) { 
continue; } auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); - builder.CreateCall(module.getFunction("jit_HashAggrResizeVector"), {vector, numGroups}); + codegen.resizeResultVector(vector, numGroups); } builder.CreateCondBr(builder.CreateICmpSLE(numGroups, builder.getInt32(0)), end, loop); @@ -667,44 +601,17 @@ bool genExtractIR( for (auto i = 0; i < slots.size(); ++i) { const auto& slot = slots[i]; - if (slot.decimal || slot.accumulatorKind == HashAggrJitValueKind::Int128) { + if (slot.ops == nullptr || slot.ops->canExtract == nullptr || + !slot.ops->canExtract(slot, partialOutput)) { continue; } auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); - HashAggrJitValueKind resultKind = slot.accumulatorKind; - llvm::Value* value = nullptr; - llvm::Value* isNull = nullptr; - if (partialOutput && slot.kind == HashAggrJitKind::Avg) { - auto* sum = loadValue( - builder, group, llvmType(builder, slot.accumulatorKind), slot.offset); - auto* count = - loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); - auto* isNullValue = builder.CreateZExt( - isAccumulatorNull(builder, group, slot), builder.getInt8Ty()); - builder.CreateCall( - module.getFunction("jit_HashAggrSetPartialAvgDouble"), - {vector, row, sum, count, isNullValue}); - continue; - } - if (slot.kind == HashAggrJitKind::Avg) { - auto* sum = loadValue(builder, group, llvmType(builder, slot.accumulatorKind), slot.offset); - auto* count = loadValue(builder, group, builder.getInt64Ty(), slot.offset + 8); - auto* countIsZero = builder.CreateICmpEQ(count, builder.getInt64(0)); - auto* divisor = builder.CreateSIToFP(count, llvmType(builder, slot.accumulatorKind)); - value = builder.CreateFDiv(sum, divisor); - isNull = builder.CreateZExt(countIsZero, builder.getInt8Ty()); - } else { - value = loadValue(builder, group, llvmType(builder, 
resultKind), slot.offset); - isNull = slot.kind == HashAggrJitKind::Count - ? builder.getInt8(0) - : builder.CreateZExt(isAccumulatorNull(builder, group, slot), builder.getInt8Ty()); - } - const auto setter = setFlatValueFunction(resultKind); - if (setter.empty()) { - continue; + if (slot.ops->extract == nullptr) { + return true; } - builder.CreateCall(module.getFunction(setter), {vector, row, value, isNull}); + slot.ops->extract( + codegen, group, slot, HashAggrJitExtractTarget{vector, row, partialOutput}); } auto* next = builder.CreateAdd(row, builder.getInt32(1)); @@ -781,7 +688,8 @@ bool isHashAggrJitSupportedType(TypeKind kind) { std::string HashAggrJitDescriptor::signature() const { return fmt::format( - "{}_{}_{}_{}_{}", + "{}_{}_{}_{}_{}_{}", + ops != nullptr ? ops->id : "unknown", static_cast(kind), hashAggrJitValueKindName(inputKind), hashAggrJitValueKindName(accumulatorKind), @@ -794,7 +702,8 @@ std::string HashAggrJitChunk::functionName() const { out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" << slots_.size(); for (const auto& slot : slots_) { - out << "_" << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) + out << "_" << (slot.ops != nullptr ? slot.ops->id : "unknown") << "_" + << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) << hashAggrJitValueKindName(slot.accumulatorKind) << "o" << slot.offset << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) << (slot.countStar ? "s" : "x") << (slot.mergeInput ? 
"g" : "r") @@ -808,7 +717,8 @@ bool HashAggrJitChunk::canExtract() const { return false; } for (const auto& slot : slots_) { - if (slot.decimal || slot.accumulatorKind == HashAggrJitValueKind::Int128) { + if (slot.ops == nullptr || slot.ops->canExtract == nullptr || + !slot.ops->canExtract(slot, partialOutput_)) { return false; } } diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index ddc33cd74..d964e3319 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -8,11 +8,19 @@ #include #include +#include +#include +#include + #include "bolt/jit/CompiledModule.h" #include "bolt/type/Type.h" namespace bytedance::bolt::jit { +class HashAggrJitCodegen; +struct HashAggrJitSlot; +struct HashAggrJitExtractTarget; + struct HashAggrJitPlanContext { bool isRawInput{false}; bool isPartialOutput{false}; @@ -49,10 +57,36 @@ struct HashAggrJitDescriptor { bool countStar{false}; bool mergeInput{false}; bool decimal{false}; + const struct HashAggrJitOps* ops{nullptr}; std::string signature() const; }; +struct HashAggrJitOps { + using CreateFn = + void (*)(HashAggrJitCodegen&, llvm::Value* group, const HashAggrJitSlot&); + using AddFn = void (*)( + HashAggrJitCodegen&, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot&, + bool checkInputNulls, + llvm::BasicBlock* nextBlock); + using CanExtractFn = bool (*)(const HashAggrJitSlot&, bool partialOutput); + using ExtractFn = void (*)( + HashAggrJitCodegen&, + llvm::Value* group, + const HashAggrJitSlot&, + const HashAggrJitExtractTarget&); + + const char* id; + CreateFn create; + AddFn add; + CanExtractFn canExtract; + ExtractFn extract; +}; + struct HashAggrJitSlot { int32_t aggregateIndex; HashAggrJitKind kind; @@ -64,6 +98,77 @@ struct HashAggrJitSlot { bool countStar{false}; bool mergeInput{false}; bool decimal{false}; + const HashAggrJitOps* ops{nullptr}; +}; + +struct HashAggrJitExtractTarget { + llvm::Value* 
resultVector; + llvm::Value* row; + bool partialOutput; +}; + +class HashAggrJitCodegen { + public: + explicit HashAggrJitCodegen(llvm::Module& module); + + llvm::Module& module() const { + return module_; + } + + llvm::IRBuilder<>& builder() const { + return *builder_; + } + + void setBuilder(llvm::IRBuilder<>* builder) { + builder_ = builder; + } + + llvm::Type* llvmType(HashAggrJitValueKind kind) const; + llvm::Value* loadDecodedValue( + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot) const; + llvm::Value* isAccumulatorNull( + llvm::Value* group, + const HashAggrJitSlot& slot) const; + void clearAccumulatorNull(llvm::Value* group, const HashAggrJitSlot& slot) + const; + void setAccumulatorNull(llvm::Value* group, const HashAggrJitSlot& slot) + const; + llvm::LoadInst* loadValue(llvm::Value* row, llvm::Type* type, int32_t offset) + const; + void storeValue( + llvm::Value* row, + llvm::Type* type, + int32_t offset, + llvm::Value* value) const; + llvm::Value* castValue( + llvm::Value* value, + HashAggrJitValueKind from, + HashAggrJitValueKind to) const; + bool isFloatKind(HashAggrJitValueKind kind) const; + llvm::Value* loadAvgMergeField( + llvm::Value* decoded, + llvm::Value* row, + int32_t field, + llvm::Type* type) const; + void emitFlatValue( + llvm::Value* vector, + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* value, + llvm::Value* isNull) const; + void resizeResultVector(llvm::Value* vector, llvm::Value* size) const; + void emitPartialAvgResult( + llvm::Value* vector, + llvm::Value* row, + llvm::Value* sum, + llvm::Value* count, + llvm::Value* isNull) const; + + private: + llvm::Module& module_; + llvm::IRBuilder<>* builder_{nullptr}; }; using HashAggrJitAddDenseFunc = void (*)(char** groups, int32_t numRows, char** decodedInputs); From 0aa3cf4044bf3f10c6061cbacf6f683315bb93d4 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 3 Jun 2026 20:14:53 +0800 Subject: [PATCH 08/98] temperarily disable extractValues jit 
--- bolt/functions/lib/aggregates/AverageAggregateBase.h | 11 +++++++---- bolt/functions/lib/aggregates/SumAggregateBase.h | 5 +++-- .../functions/prestosql/aggregates/CountAggregate.cpp | 2 +- .../prestosql/aggregates/MinMaxAggregates.cpp | 3 ++- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index 8b741130e..8ced96040 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -244,10 +244,13 @@ class AverageAggregateBase : public exec::Aggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot& slot, bool partialOutput) { - if (slot.decimal || slot.accumulatorKind == jit::HashAggrJitValueKind::Int128) { - return false; - } - return !partialOutput || slot.accumulatorKind == jit::HashAggrJitValueKind::Double; + // if (slot.decimal || slot.accumulatorKind == + // jit::HashAggrJitValueKind::Int128) { + // return false; + // } + // return !partialOutput || slot.accumulatorKind == + // jit::HashAggrJitValueKind::Double; + return false; } static void compileHashAggrJitExtract( diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index 7cb2fb3cd..418d6db8f 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -172,8 +172,9 @@ class SumAggregateBase static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot& slot, bool) { - return !slot.decimal && - slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; + // return !slot.decimal && + // slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; + return false; } static void compileHashAggrJitExtract( diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index bc7a0b2a4..d253e44eb 100644 --- 
a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -129,7 +129,7 @@ class CountAggregate : public SimpleNumericAggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { - return true; + return false; } static void compileHashAggrJitExtract( diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index be194eebb..9e1fa7692 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -142,7 +142,8 @@ class MinMaxAggregate : public SimpleNumericAggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot& slot, bool) { - return slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; + // return slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; + return false; } static void compileHashAggrJitExtract( From 3ea7e2bcbd1e831dca47f9919c3667bc24427155 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 3 Jun 2026 21:25:12 +0800 Subject: [PATCH 09/98] modify build version From ecdfa7621d309a533ad910eb9e7e303d9bdcb45d Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 4 Jun 2026 11:00:41 +0800 Subject: [PATCH 10/98] add debug logs --- bolt/exec/GroupingSet.cpp | 143 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 99c6b54e9..b11262d42 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -29,6 +29,7 @@ */ #include "bolt/exec/GroupingSet.h" +#include #include "bolt/common/base/Exceptions.h" #include "bolt/common/base/SpillConfig.h" #include "bolt/common/testutil/TestValue.h" @@ -62,6 +63,56 @@ bool areAllLazyNotLoaded(const std::vector& vectors) { }); } +#ifdef ENABLE_BOLT_JIT +std::string hashAggrJitTypeName(const TypePtr& type) { + return type == nullptr ? 
"null" : type->toString(); +} + +std::string hashAggrJitSlotDebugString( + const jit::HashAggrJitSlot& slot, + const AggregateInfo* aggregate = nullptr) { + std::ostringstream out; + out << "agg#" << slot.aggregateIndex; + if (aggregate != nullptr) { + out << "(" << aggregate->name << ")"; + out << " inputs=["; + for (size_t i = 0; i < aggregate->rawInputTypes.size(); ++i) { + if (i > 0) { + out << ", "; + } + out << hashAggrJitTypeName(aggregate->rawInputTypes[i]); + } + out << "]"; + } + out << " kind=" << static_cast(slot.kind) + << " inputKind=" << jit::hashAggrJitValueKindName(slot.inputKind) + << " accKind=" << jit::hashAggrJitValueKindName(slot.accumulatorKind) + << " offset=" << slot.offset << " nullByte=" << slot.nullByte + << " nullMask=" << static_cast(slot.nullMask) + << " countStar=" << slot.countStar + << " mergeInput=" << slot.mergeInput << " decimal=" << slot.decimal + << " ops=" << (slot.ops != nullptr ? slot.ops->id : "null"); + return out.str(); +} + +std::string hashAggrJitChunkDebugString( + const jit::HashAggrJitChunk& chunk, + const std::vector& aggregates) { + std::ostringstream out; + out << chunk.functionName() << " slots=["; + for (size_t i = 0; i < chunk.slots().size(); ++i) { + if (i > 0) { + out << "; "; + } + const auto& slot = chunk.slots()[i]; + out << hashAggrJitSlotDebugString(slot, &aggregates[slot.aggregateIndex]); + } + out << "] canExtract=" << chunk.canExtract() + << " enabled=" << chunk.enabled(); + return out.str(); +} +#endif + std::optional makeHashAggrJitSlot( int32_t aggregateIndex, const AggregateInfo& aggregate, @@ -822,6 +873,9 @@ const SelectivityVector& GroupingSet::getSelectivityVector( void GroupingSet::maybeCreateHashAggrJitPlan() { hashAggrJitChunks_.clear(); if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { + LOG(INFO) << "HashAggrJit plan disabled: enableHashAggrJit=" + << queryConfig_.enableHashAggrJit() << " isGlobal=" << isGlobal_ + << " ignoreNullKeys=" << ignoreNullKeys_; return; } 
@@ -830,17 +884,43 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto compileMinCount = std::max(1, queryConfig_.hashAggrJitCompileMinCount()); const auto minChunkWidth = std::max(minFuseWidth, compileMinCount); + LOG(INFO) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ + << " isPartial=" << isPartial_ + << " aggregates=" << aggregates_.size() + << " minFuseWidth=" << minFuseWidth + << " maxFuseWidth=" << maxFuseWidth + << " compileMinCount=" << compileMinCount + << " minChunkWidth=" << minChunkWidth; std::vector currentChunkSlots; currentChunkSlots.reserve(maxFuseWidth); auto flushChunk = [&]() { if (currentChunkSlots.size() < minChunkWidth) { + if (!currentChunkSlots.empty()) { + std::ostringstream out; + for (size_t i = 0; i < currentChunkSlots.size(); ++i) { + if (i > 0) { + out << "; "; + } + const auto& slot = currentChunkSlots[i]; + out << hashAggrJitSlotDebugString(slot, &aggregates_[slot.aggregateIndex]); + } + LOG(INFO) << "HashAggrJit discard chunk candidate due to width " + << currentChunkSlots.size() << " < " << minChunkWidth + << ": [" << out.str() << "]"; + } currentChunkSlots.clear(); return; } jit::HashAggrJitChunk chunk(std::move(currentChunkSlots), isPartial_); if (chunk.codegen()) { hashAggrJitChunks_.push_back(std::move(chunk)); + LOG(INFO) << "HashAggrJit formed chunk: " + << hashAggrJitChunkDebugString( + hashAggrJitChunks_.back(), aggregates_); + } else { + LOG(INFO) << "HashAggrJit chunk codegen failed for chunk " + << chunk.functionName(); } currentChunkSlots.clear(); currentChunkSlots.reserve(maxFuseWidth); @@ -849,6 +929,24 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { for (auto i = 0; i < aggregates_.size(); ++i) { auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); if (!slot.has_value()) { + LOG(INFO) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" + << aggregates_[i].name << ") rawInputTypes=[" + << [&]() { + std::ostringstream out; + for (size_t j = 0; j < 
aggregates_[i].rawInputTypes.size(); ++j) { + if (j > 0) { + out << ", "; + } + out << hashAggrJitTypeName(aggregates_[i].rawInputTypes[j]); + } + return out.str(); + }() + << "] distinct=" << aggregates_[i].distinct + << " mask=" << aggregates_[i].mask.has_value() + << " sortingKeys=" << aggregates_[i].sortingKeys.size() + << " inputs=" << aggregates_[i].inputs.size() + << " intermediateType=" + << hashAggrJitTypeName(aggregates_[i].intermediateType); flushChunk(); continue; } @@ -856,10 +954,14 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { if (currentChunkSlots.size() >= maxFuseWidth) { flushChunk(); } + LOG(INFO) << "HashAggrJit aggregate is JIT-able: " + << hashAggrJitSlotDebugString(*slot, &aggregates_[i]); currentChunkSlots.push_back(*slot); } flushChunk(); + LOG(INFO) << "HashAggrJit planning finished: totalChunks=" + << hashAggrJitChunks_.size(); } void GroupingSet::runHashAggrJitChunks( @@ -870,12 +972,19 @@ void GroupingSet::runHashAggrJitChunks( std::vector& jitExecuted) { if (hashAggrJitChunks_.empty() || hasSpilled() || bypassProbeHT_ || supportRowBasedOutput_ || !activeRows_.isAllSelected()) { + LOG(INFO) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() + << " hasSpilled=" << hasSpilled() + << " bypassProbeHT=" << bypassProbeHT_ + << " supportRowBasedOutput=" << supportRowBasedOutput_ + << " activeRowsAllSelected=" << activeRows_.isAllSelected(); return; } jitExecuted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { if (!chunk.enabled()) { + LOG(INFO) << "HashAggrJit chunk disabled, skip add: " + << hashAggrJitChunkDebugString(chunk, aggregates_); continue; } @@ -885,6 +994,7 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitDecodedPtrs_.assign(numSlots, nullptr); bool canRunChunk = true; + std::string skipReason; bool inputsMayHaveNulls = false; for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { const auto& slot = chunk.slots()[slotIndex]; @@ -892,11 +1002,13 @@ void 
GroupingSet::runHashAggrJitChunks( if (aggregate.mask.has_value() || aggregate.distinct || !aggregate.sortingKeys.empty()) { canRunChunk = false; + skipReason = "mask/distinct/sortingKeys not supported"; break; } const auto& rows = getSelectivityVector(slot.aggregateIndex); if (&rows != &activeRows_ || !rows.hasSelections()) { canRunChunk = false; + skipReason = "selectivity vector is not dense activeRows or has no selections"; break; } if (slot.countStar) { @@ -904,6 +1016,7 @@ void GroupingSet::runHashAggrJitChunks( } if (aggregate.inputs.size() != 1) { canRunChunk = false; + skipReason = "input count is not 1 for non-count(*) slot"; break; } @@ -915,6 +1028,7 @@ void GroupingSet::runHashAggrJitChunks( } if (mayPushdown && mayPushdown_[slot.aggregateIndex] && isLazyNotLoaded(*arg)) { canRunChunk = false; + skipReason = "lazy input with pushdown enabled"; break; } hashAggrJitInputVectors_[slotIndex] = arg; @@ -925,6 +1039,9 @@ void GroupingSet::runHashAggrJitChunks( } if (!canRunChunk) { + LOG(INFO) << "HashAggrJit chunk cannot run add path, fallback to non-JIT: " + << hashAggrJitChunkDebugString(chunk, aggregates_) + << " reason=" << skipReason; continue; } @@ -934,6 +1051,8 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitNewGroups_[i] = groups[newGroups[i]]; } chunk.init(hashAggrJitNewGroups_.data(), newGroups.size()); + LOG(INFO) << "HashAggrJit initialized new groups for chunk " + << chunk.functionName() << " newGroups=" << newGroups.size(); } chunk.addDense( @@ -941,8 +1060,15 @@ void GroupingSet::runHashAggrJitChunks( activeRows_.end(), hashAggrJitDecodedPtrs_.data(), inputsMayHaveNulls); + LOG(INFO) << "HashAggrJit add executed: chunk=" << chunk.functionName() + << " rows=" << activeRows_.end() + << " inputsMayHaveNulls=" << inputsMayHaveNulls + << " slots=" << hashAggrJitChunkDebugString(chunk, aggregates_); for (const auto& slot : chunk.slots()) { jitExecuted[slot.aggregateIndex] = 1; + LOG(INFO) << "HashAggrJit slot executed in add path: " + << 
hashAggrJitSlotDebugString( + slot, &aggregates_[slot.aggregateIndex]); } } } @@ -954,23 +1080,30 @@ void GroupingSet::runHashAggrJitExtractChunks( std::vector& jitExtracted) { if (hashAggrJitChunks_.empty() || groups.empty() || hasSpilled() || supportRowBasedOutput_) { + LOG(INFO) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() + << " groups=" << groups.size() << " hasSpilled=" << hasSpilled() + << " supportRowBasedOutput=" << supportRowBasedOutput_; return; } jitExtracted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { if (!chunk.canExtract()) { + LOG(INFO) << "HashAggrJit chunk cannot extract, fallback to non-JIT extract: " + << hashAggrJitChunkDebugString(chunk, aggregates_); continue; } const auto numSlots = chunk.slots().size(); hashAggrJitResultPtrs_.assign(numSlots, nullptr); bool canRunChunk = true; + std::string skipReason; for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { const auto& slot = chunk.slots()[slotIndex]; const auto& aggregate = aggregates_[slot.aggregateIndex]; if (aggregate.distinct || aggregate.mask.has_value() || !aggregate.sortingKeys.empty()) { canRunChunk = false; + skipReason = "distinct/mask/sortingKeys not supported"; break; } auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); @@ -980,16 +1113,26 @@ void GroupingSet::runHashAggrJitExtractChunks( : VectorEncoding::Simple::FLAT; if (aggregateVector->encoding() != expectedEncoding) { canRunChunk = false; + skipReason = "unexpected result vector encoding"; break; } hashAggrJitResultPtrs_[slotIndex] = reinterpret_cast(aggregateVector.get()); } if (!canRunChunk) { + LOG(INFO) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " + << hashAggrJitChunkDebugString(chunk, aggregates_) + << " reason=" << skipReason; continue; } chunk.extract(groups.data(), groups.size(), hashAggrJitResultPtrs_.data()); + LOG(INFO) << "HashAggrJit extract executed: chunk=" << chunk.functionName() + << " 
groups=" << groups.size() + << " slots=" << hashAggrJitChunkDebugString(chunk, aggregates_); for (const auto& slot : chunk.slots()) { jitExtracted[slot.aggregateIndex] = 1; + LOG(INFO) << "HashAggrJit slot executed in extract path: " + << hashAggrJitSlotDebugString( + slot, &aggregates_[slot.aggregateIndex]); } } } From 4be72fe119e0bb4abfc926201c98e25b4438a420 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 4 Jun 2026 14:06:10 +0800 Subject: [PATCH 11/98] fix for AggregateCompanionAdapter --- bolt/exec/Aggregate.cpp | 22 +++++++++++----------- bolt/exec/AggregateCompanionAdapter.cpp | 14 ++++++++++++++ bolt/exec/AggregateCompanionAdapter.h | 8 ++++++++ 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index 8de8ccc0f..7ca9137d2 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -336,17 +336,17 @@ jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( int32_t aggregateIndex, const jit::HashAggrJitDescriptor& descriptor) const { return jit::HashAggrJitSlot{ - aggregateIndex, - descriptor.kind, - descriptor.inputKind, - descriptor.accumulatorKind, - accumulatorOffset(), - accumulatorNullByte(), - accumulatorNullMask(), - descriptor.countStar, - descriptor.mergeInput, - descriptor.decimal, - descriptor.ops}; + .aggregateIndex = aggregateIndex, + .kind = descriptor.kind, + .inputKind = descriptor.inputKind, + .accumulatorKind = descriptor.accumulatorKind, + .offset = accumulatorOffset(), + .nullByte = accumulatorNullByte(), + .nullMask = accumulatorNullMask(), + .countStar = descriptor.countStar, + .mergeInput = descriptor.mergeInput, + .decimal = descriptor.decimal, + .ops = descriptor.ops}; } #endif diff --git a/bolt/exec/AggregateCompanionAdapter.cpp b/bolt/exec/AggregateCompanionAdapter.cpp index ab082babf..77a468ad1 100644 --- a/bolt/exec/AggregateCompanionAdapter.cpp +++ b/bolt/exec/AggregateCompanionAdapter.cpp @@ -42,6 +42,7 @@ void 
AggregateCompanionFunctionBase::setOffsetsInternal( int32_t nullByte, uint8_t nullMask, int32_t rowSizeOffset) { + Aggregate::setOffsetsInternal(offset, nullByte, nullMask, rowSizeOffset); fn_->setOffsets(offset, nullByte, nullMask, rowSizeOffset); } @@ -65,6 +66,19 @@ bool AggregateCompanionFunctionBase::supportsToIntermediate() const { return fn_->supportsToIntermediate(); } +#ifdef ENABLE_BOLT_JIT +bool AggregateCompanionFunctionBase::supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const { + return fn_->supportsHashAggrJit(context); +} + +std::optional +AggregateCompanionFunctionBase::createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const { + return fn_->createHashAggrJitDescriptor(context); +} +#endif + bool AggregateCompanionFunctionBase::supportAccumulatorSerde() const { return fn_->supportAccumulatorSerde(); } diff --git a/bolt/exec/AggregateCompanionAdapter.h b/bolt/exec/AggregateCompanionAdapter.h index 57586c8b7..5fbec605c 100644 --- a/bolt/exec/AggregateCompanionAdapter.h +++ b/bolt/exec/AggregateCompanionAdapter.h @@ -52,6 +52,14 @@ class AggregateCompanionFunctionBase : public Aggregate { bool supportsToIntermediate() const override final; +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override final; + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override final; +#endif + bool supportAccumulatorSerde() const override final; uint32_t getAccumulatorSerializeSize(char* group) const override final; From aaa2ed6c81e041379a9ae903315f4f6b63bc469d Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 4 Jun 2026 14:42:44 +0800 Subject: [PATCH 12/98] fix for AggregateCompanionAdapter --- bolt/exec/AggregateCompanionAdapter.cpp | 48 ++++++++++++++++++++++--- bolt/exec/AggregateCompanionAdapter.h | 23 ++++++++++-- 2 files changed, 64 insertions(+), 7 deletions(-) diff --git 
a/bolt/exec/AggregateCompanionAdapter.cpp b/bolt/exec/AggregateCompanionAdapter.cpp index 77a468ad1..f032c7c2b 100644 --- a/bolt/exec/AggregateCompanionAdapter.cpp +++ b/bolt/exec/AggregateCompanionAdapter.cpp @@ -35,8 +35,22 @@ #include "bolt/exec/RowContainer.h" #include "bolt/expression/SignatureBinder.h" #include "bolt/functions/lib/aggregates/AggregateToIntermediate.h" + namespace bytedance::bolt::exec { +namespace { + +#ifdef ENABLE_BOLT_JIT +jit::HashAggrJitPlanContext toUnderlyingMergeContext( + const jit::HashAggrJitPlanContext& context) { + auto adapted = context; + adapted.isRawInput = false; + return adapted; +} +#endif + +} // namespace + void AggregateCompanionFunctionBase::setOffsetsInternal( int32_t offset, int32_t nullByte, @@ -68,14 +82,14 @@ bool AggregateCompanionFunctionBase::supportsToIntermediate() const { #ifdef ENABLE_BOLT_JIT bool AggregateCompanionFunctionBase::supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const { - return fn_->supportsHashAggrJit(context); + const jit::HashAggrJitPlanContext& /*context*/) const { + return false; } std::optional AggregateCompanionFunctionBase::createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const { - return fn_->createHashAggrJitDescriptor(context); + const jit::HashAggrJitPlanContext& /*context*/) const { + return std::nullopt; } #endif @@ -180,6 +194,19 @@ void AggregateCompanionAdapter::PartialFunction::extractValues( fn_->extractAccumulators(groups, numGroups, result); } +#ifdef ENABLE_BOLT_JIT +bool AggregateCompanionAdapter::PartialFunction::supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const { + return fn_->supportsHashAggrJit(context); +} + +std::optional +AggregateCompanionAdapter::PartialFunction::createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const { + return fn_->createHashAggrJitDescriptor(context); +} +#endif + void AggregateCompanionAdapter::MergeFunction::addRawInput( char** groups, const 
SelectivityVector& rows, @@ -198,6 +225,19 @@ void AggregateCompanionAdapter::MergeFunction::addSingleGroupRawInput( fn_->addSingleGroupIntermediateResults(group, rows, args, mayPushdown); } +#ifdef ENABLE_BOLT_JIT +bool AggregateCompanionAdapter::MergeFunction::supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const { + return fn_->supportsHashAggrJit(toUnderlyingMergeContext(context)); +} + +std::optional +AggregateCompanionAdapter::MergeFunction::createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const { + return fn_->createHashAggrJitDescriptor(toUnderlyingMergeContext(context)); +} +#endif + void AggregateCompanionAdapter::MergeFunction::toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/exec/AggregateCompanionAdapter.h b/bolt/exec/AggregateCompanionAdapter.h index 5fbec605c..e23b3c35d 100644 --- a/bolt/exec/AggregateCompanionAdapter.h +++ b/bolt/exec/AggregateCompanionAdapter.h @@ -33,6 +33,7 @@ #include "bolt/common/memory/HashStringAllocator.h" #include "bolt/exec/Aggregate.h" #include "bolt/expression/VectorFunction.h" + namespace bytedance::bolt::exec { class AggregateCompanionFunctionBase : public Aggregate { @@ -54,10 +55,10 @@ class AggregateCompanionFunctionBase : public Aggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const override final; + const jit::HashAggrJitPlanContext& context) const override; std::optional createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const override final; + const jit::HashAggrJitPlanContext& context) const override; #endif bool supportAccumulatorSerde() const override final; @@ -132,6 +133,14 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : AggregateCompanionFunctionBase{std::move(fn), resultType} {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override; + + std::optional 
createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override; +#endif + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; @@ -147,7 +156,7 @@ struct AggregateCompanionAdapter { void toIntermediate( const SelectivityVector& rows, std::vector& args, - VectorPtr& result) const override final; + VectorPtr& result) const final; void addRawInput( char** groups, @@ -161,6 +170,14 @@ struct AggregateCompanionAdapter { const std::vector& args, bool mayPushdown) override; +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override; + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override; +#endif + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; From 93ea366437b60362f9b05e35141b8d3aa188e755 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 4 Jun 2026 14:57:56 +0800 Subject: [PATCH 13/98] add debug logs for llmv ir --- bolt/jit/aggregation/HashAggrJit.cpp | 41 +++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index ad1ec80e3..ae0c9d968 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -2,9 +2,11 @@ #include "bolt/jit/aggregation/HashAggrJit.h" +#include #include #include #include +#include #include #include @@ -17,6 +19,32 @@ extern "C" { namespace { +void logHashAggrJitFunctionIR( + const llvm::Module& module, + const std::string& moduleKey, + llvm::StringRef functionName, + llvm::StringRef stage, + bool hasError) { + if (!VLOG_IS_ON(1)) { + return; + } + const auto* function = module.getFunction(functionName); + if (function == nullptr) { + VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + << " stage=" << stage.str() << " function=" << functionName.str() + << " error=" << hasError << ": "; 
+ return; + } + std::string ir; + llvm::raw_string_ostream out(ir); + function->print(out); + out.flush(); + VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + << " stage=" << stage.str() << " function=" << functionName.str() + << " error=" << hasError << ":\n" + << ir; +} + struct JitDecimalSumState { bytedance::bolt::int128_t sum{0}; int64_t overflow{0}; @@ -752,10 +780,21 @@ bool HashAggrJitChunk::codegen() { const auto extractFn = extractFunctionName(); module_ = jit->CompileModule( [&](llvm::Module& module) { - return genInitIR(module, initFn, slots_) || + const bool hasError = genInitIR(module, initFn, slots_) || genAddDenseIR(module, addFn, slots_, true) || genAddDenseIR(module, addNoNullFn, slots_, false) || genExtractIR(module, extractFn, slots_, partialOutput_); + logHashAggrJitFunctionIR(module, moduleKey, initFn, "init", hasError); + logHashAggrJitFunctionIR(module, moduleKey, addFn, "add_dense", hasError); + logHashAggrJitFunctionIR( + module, + moduleKey, + addNoNullFn, + "add_dense_no_null", + hasError); + logHashAggrJitFunctionIR( + module, moduleKey, extractFn, "extract", hasError); + return hasError; }, moduleKey); if (!module_) { From 3485ac7fde5df0f3a787e8582961ac0a6b63a9c6 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 4 Jun 2026 15:35:17 +0800 Subject: [PATCH 14/98] support decimal input for avg_partial and sum_partial --- .../lib/aggregates/AverageAggregateBase.h | 22 ++++++++++++++---- .../lib/aggregates/SumAggregateBase.h | 23 +++++++++++++++---- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index 8ced96040..5bd6e02ef 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -109,8 +109,11 @@ class AverageAggregateBase : public exec::Aggregate { return false; } if (context.isRawInput) { - return 
context.inputType->isDecimal() || - jit::isHashAggrJitSupportedType(context.inputType->kind()) || + if (context.inputType->isDecimal()) { + return context.inputType->isShortDecimal() || + context.inputType->isLongDecimal(); + } + return jit::isHashAggrJitSupportedType(context.inputType->kind()) || context.inputType->kind() == TypeKind::HUGEINT; } return context.inputType->isRow() && context.inputType->size() == 2 && @@ -136,9 +139,18 @@ class AverageAggregateBase : public exec::Aggregate { } const bool decimal = context.inputType->isDecimal(); - auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); - if (!inputKind.has_value()) { - return std::nullopt; + std::optional inputKind; + if (decimal) { + inputKind = context.inputType->isShortDecimal() + ? std::optional< + jit::HashAggrJitValueKind>{jit::HashAggrJitValueKind::Int64} + : std::optional{ + jit::HashAggrJitValueKind::Int128}; + } else { + inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } } return jit::HashAggrJitDescriptor{ jit::HashAggrJitKind::Avg, diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index 418d6db8f..bbb10af43 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -72,8 +72,12 @@ class SumAggregateBase if (context.inputType->isRow()) { return false; } - return context.inputType->isDecimal() || - jit::isHashAggrJitSupportedType(context.inputType->kind()) || + if (context.inputType->isDecimal()) { + return context.isRawInput && + (context.inputType->isShortDecimal() || + context.inputType->isLongDecimal()); + } + return jit::isHashAggrJitSupportedType(context.inputType->kind()) || context.inputType->kind() == TypeKind::HUGEINT; } @@ -84,9 +88,18 @@ class SumAggregateBase } const bool decimal = context.isRawInput && context.inputType->isDecimal(); - auto inputKind = 
jit::hashAggrJitValueKind(context.inputType->kind()); - if (!inputKind.has_value()) { - return std::nullopt; + std::optional inputKind; + if (decimal) { + inputKind = context.inputType->isShortDecimal() + ? std::optional< + jit::HashAggrJitValueKind>{jit::HashAggrJitValueKind::Int64} + : std::optional{ + jit::HashAggrJitValueKind::Int128}; + } else { + inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } } auto accumulatorKind = decimal From 6c25ffb656c37c6cd34c5bfab602acaedfcbb3a7 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 5 Jun 2026 15:54:28 +0800 Subject: [PATCH 15/98] push down jit implements from xxxbase to xxx --- bolt/exec/RowContainer.cpp | 28 ++ .../lib/aggregates/AverageAggregateBase.h | 203 ------------ .../lib/aggregates/SumAggregateBase.h | 153 ---------- .../sparksql/aggregates/AverageAggregate.cpp | 289 ++++++++++++++++++ .../sparksql/aggregates/DecimalSumAggregate.h | 143 +++++++++ .../sparksql/aggregates/SumAggregate.cpp | 107 ++++++- bolt/jit/aggregation/HashAggrJit.cpp | 107 ++++++- bolt/jit/aggregation/HashAggrJit.h | 8 +- 8 files changed, 675 insertions(+), 363 deletions(-) diff --git a/bolt/exec/RowContainer.cpp b/bolt/exec/RowContainer.cpp index f887a4479..12c96caa1 100644 --- a/bolt/exec/RowContainer.cpp +++ b/bolt/exec/RowContainer.cpp @@ -1756,6 +1756,16 @@ jit_GetDecodedRowFieldDouble(char* vec, int32_t index, int32_t field) { return fieldDecoded.valueAt(decoded->index(index)); } +__attribute__((__visibility__("default"))) int8_t jit_GetDecodedRowFieldI8( + char* vec, + int32_t index, + int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + __attribute__((__visibility__("default"))) int64_t jit_GetDecodedRowFieldI64( char* vec, int32_t index, @@ -1765,6 +1775,24 @@ 
__attribute__((__visibility__("default"))) int64_t jit_GetDecodedRowFieldI64( bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); return fieldDecoded.valueAt(decoded->index(index)); } + +__attribute__((__visibility__("default"))) bytedance::bolt::int128_t +jit_GetDecodedRowFieldI128(char* vec, int32_t index, int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.valueAt(decoded->index(index)); +} + +__attribute__((__visibility__("default"))) int8_t jit_GetDecodedRowFieldIsNull( + char* vec, + int32_t index, + int32_t field) { + auto* decoded = reinterpret_cast(vec); + auto* rowVector = decoded->base()->as(); + bytedance::bolt::DecodedVector fieldDecoded(*rowVector->childAt(field)); + return fieldDecoded.isNullAt(decoded->index(index)); +} // get decoded value string __attribute__((__visibility__("default"))) const char* jit_GetDecodedValueStringView(char* vec, int32_t index) { diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index 5bd6e02ef..f8869255f 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -102,209 +102,6 @@ class AverageAggregateBase : public exec::Aggregate { return true; } -#ifdef ENABLE_BOLT_JIT - bool supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const override { - if (context.inputCount != 1 || !context.inputType) { - return false; - } - if (context.isRawInput) { - if (context.inputType->isDecimal()) { - return context.inputType->isShortDecimal() || - context.inputType->isLongDecimal(); - } - return jit::isHashAggrJitSupportedType(context.inputType->kind()) || - context.inputType->kind() == TypeKind::HUGEINT; - } - return context.inputType->isRow() && context.inputType->size() == 2 && - 
context.inputType->childAt(1)->kind() == TypeKind::BIGINT && - context.inputType->childAt(0)->kind() == TypeKind::DOUBLE; - } - - std::optional createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const override { - if (!supportsHashAggrJit(context)) { - return std::nullopt; - } - - if (!context.isRawInput) { - return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Avg, - jit::HashAggrJitValueKind::Double, - jit::HashAggrJitValueKind::Double, - false, - true, - false, - hashAggrJitOps()}; - } - - const bool decimal = context.inputType->isDecimal(); - std::optional inputKind; - if (decimal) { - inputKind = context.inputType->isShortDecimal() - ? std::optional< - jit::HashAggrJitValueKind>{jit::HashAggrJitValueKind::Int64} - : std::optional{ - jit::HashAggrJitValueKind::Int128}; - } else { - inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); - if (!inputKind.has_value()) { - return std::nullopt; - } - } - return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Avg, - *inputKind, - decimal ? 
jit::HashAggrJitValueKind::Int128 - : jit::HashAggrJitValueKind::Double, - false, - false, - decimal, - hashAggrJitOps()}; - } - - private: - static void compileHashAggrJitCreate( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - codegen.setAccumulatorNull(group, slot); - if (slot.decimal) { - codegen.builder().CreateCall( - codegen.module().getFunction("jit_HashAggrInitDecimalAvg"), - {group, codegen.builder().getInt32(slot.offset)}); - return; - } - codegen.storeValue( - group, - codegen.llvmType(slot.accumulatorKind), - slot.offset, - llvm::ConstantFP::get(codegen.llvmType(slot.accumulatorKind), 0.0)); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().getInt64(0)); - } - - static void compileHashAggrJitAdd( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - if (slot.mergeInput) { - codegen.clearAccumulatorNull(group, slot); - auto* sum = codegen.loadAvgMergeField( - decoded, row, 0, codegen.builder().getDoubleTy()); - auto* count = codegen.loadAvgMergeField( - decoded, row, 1, codegen.builder().getInt64Ty()); - auto* oldSum = codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); - codegen.storeValue( - group, - codegen.builder().getDoubleTy(), - slot.offset, - codegen.builder().CreateFAdd(oldSum, sum)); - auto* oldCount = - codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset + 8); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().CreateAdd(oldCount, count)); - return; - } - - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); - if (slot.decimal) { - codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? 
"jit_HashAggrUpdateDecimalAvgI128" - : "jit_HashAggrUpdateDecimalAvgI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? codegen.castValue( - rawValue, - slot.inputKind, - jit::HashAggrJitValueKind::Int128) - : rawValue}); - return; - } - - auto* value = codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); - codegen.clearAccumulatorNull(group, slot); - auto* oldSum = codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); - codegen.storeValue( - group, - codegen.llvmType(slot.accumulatorKind), - slot.offset, - codegen.builder().CreateFAdd(oldSum, value)); - auto* oldCount = - codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset + 8); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot& slot, - bool partialOutput) { - // if (slot.decimal || slot.accumulatorKind == - // jit::HashAggrJitValueKind::Int128) { - // return false; - // } - // return !partialOutput || slot.accumulatorKind == - // jit::HashAggrJitValueKind::Double; - return false; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - auto* sum = - codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); - auto* count = - codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset + 8); - auto* isNull = codegen.builder().CreateZExt( - codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); - if (target.partialOutput) { - codegen.emitPartialAvgResult(target.resultVector, target.row, sum, count, isNull); - return; - } - auto* countIsZero = 
codegen.builder().CreateICmpEQ(count, codegen.builder().getInt64(0)); - auto* divisor = codegen.builder().CreateSIToFP(count, codegen.llvmType(slot.accumulatorKind)); - auto* value = codegen.builder().CreateFDiv(sum, divisor); - auto* finalIsNull = codegen.builder().CreateZExt(countIsZero, codegen.builder().getInt8Ty()); - codegen.emitFlatValue( - target.resultVector, - target.row, - slot.accumulatorKind, - value, - finalIsNull); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "avg", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } - - public: -#endif - FLATTEN void toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index bbb10af43..acafb3766 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -63,159 +63,6 @@ class SumAggregateBase return true; } -#ifdef ENABLE_BOLT_JIT - bool supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const override { - if (context.inputCount != 1 || !context.inputType) { - return false; - } - if (context.inputType->isRow()) { - return false; - } - if (context.inputType->isDecimal()) { - return context.isRawInput && - (context.inputType->isShortDecimal() || - context.inputType->isLongDecimal()); - } - return jit::isHashAggrJitSupportedType(context.inputType->kind()) || - context.inputType->kind() == TypeKind::HUGEINT; - } - - std::optional createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const override { - if (!supportsHashAggrJit(context)) { - return std::nullopt; - } - - const bool decimal = context.isRawInput && context.inputType->isDecimal(); - std::optional inputKind; - if (decimal) { - inputKind = context.inputType->isShortDecimal() - ? 
std::optional< - jit::HashAggrJitValueKind>{jit::HashAggrJitValueKind::Int64} - : std::optional{ - jit::HashAggrJitValueKind::Int128}; - } else { - inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); - if (!inputKind.has_value()) { - return std::nullopt; - } - } - - auto accumulatorKind = decimal - ? jit::HashAggrJitValueKind::Int128 - : ((*inputKind == jit::HashAggrJitValueKind::Float || - *inputKind == jit::HashAggrJitValueKind::Double) - ? jit::HashAggrJitValueKind::Double - : jit::HashAggrJitValueKind::Int64); - - return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Sum, - *inputKind, - accumulatorKind, - false, - !context.isRawInput, - decimal, - hashAggrJitOps()}; - } - - private: - static void compileHashAggrJitCreate( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - if (slot.decimal) { - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateCall( - codegen.module().getFunction( - slot.kind == jit::HashAggrJitKind::Sum - ? "jit_HashAggrInitDecimalSum" - : "jit_HashAggrInitDecimalAvg"), - {group, codegen.builder().getInt32(slot.offset)}); - return; - } - codegen.setAccumulatorNull(group, slot); - auto* accType = codegen.llvmType(slot.accumulatorKind); - if (codegen.isFloatKind(slot.accumulatorKind)) { - codegen.storeValue( - group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); - } else { - codegen.storeValue( - group, accType, slot.offset, llvm::ConstantInt::get(accType, 0)); - } - } - - static void compileHashAggrJitAdd( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); - if (slot.decimal) { - codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? 
"jit_HashAggrUpdateDecimalSumI128" - : "jit_HashAggrUpdateDecimalSumI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? codegen.castValue( - rawValue, - slot.inputKind, - jit::HashAggrJitValueKind::Int128) - : rawValue}); - return; - } - auto* value = - codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); - auto* accType = codegen.llvmType(slot.accumulatorKind); - codegen.clearAccumulatorNull(group, slot); - auto* oldValue = codegen.loadValue(group, accType, slot.offset); - auto* newValue = codegen.isFloatKind(slot.accumulatorKind) - ? codegen.builder().CreateFAdd(oldValue, value) - : codegen.builder().CreateAdd(oldValue, value); - codegen.storeValue(group, accType, slot.offset, newValue); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot& slot, - bool) { - // return !slot.decimal && - // slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; - return false; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - auto* value = codegen.loadValue( - group, codegen.llvmType(slot.accumulatorKind), slot.offset); - auto* isNull = codegen.builder().CreateZExt( - codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); - codegen.emitFlatValue( - target.resultVector, target.row, slot.accumulatorKind, value, isNull); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "sum", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } - - public: -#endif - void toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp 
b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index f1225b0f1..662b73ffe 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -42,6 +42,149 @@ class AverageAggregate explicit AverageAggregate(TypePtr resultType) : AverageAggregateBase(resultType) {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.inputCount != 1 || !context.inputType) { + return false; + } + if (context.isRawInput) { + if (context.inputType->isDecimal()) { + return false; + } + return jit::isHashAggrJitSupportedType(context.inputType->kind()) || + context.inputType->kind() == TypeKind::HUGEINT; + } + return context.inputType->isRow() && context.inputType->size() == 2 && + context.inputType->childAt(1)->kind() == TypeKind::BIGINT && + context.inputType->childAt(0)->kind() == TypeKind::DOUBLE; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + + if (!context.isRawInput) { + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Avg, + jit::HashAggrJitValueKind::Double, + jit::HashAggrJitValueKind::Double, + false, + true, + false, + hashAggrJitOps()}; + } + + auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Avg, + *inputKind, + jit::HashAggrJitValueKind::Double, + false, + false, + false, + hashAggrJitOps()}; + } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.storeValue( + group, + codegen.llvmType(slot.accumulatorKind), + slot.offset, + llvm::ConstantFP::get(codegen.llvmType(slot.accumulatorKind), 0.0)); 
+ codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().getInt64(0)); + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + if (slot.mergeInput) { + codegen.clearAccumulatorNull(group, slot); + auto* sum = codegen.loadDecodedRowField( + decoded, row, 0, jit::HashAggrJitValueKind::Double); + auto* count = codegen.loadDecodedRowField( + decoded, row, 1, jit::HashAggrJitValueKind::Int64); + auto* oldSum = + codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getDoubleTy(), + slot.offset, + codegen.builder().CreateFAdd(oldSum, sum)); + auto* oldCount = codegen.loadValue( + group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, count)); + return; + } + + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* value = + codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldSum = + codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + codegen.storeValue( + group, + codegen.llvmType(slot.accumulatorKind), + slot.offset, + codegen.builder().CreateFAdd(oldSum, value)); + auto* oldCount = codegen.loadValue( + group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot&, + bool) { + return false; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen&, + llvm::Value*, + const jit::HashAggrJitSlot&, + const 
jit::HashAggrJitExtractTarget&) {} + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "avg", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } + + public: +#endif + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) override { auto rowVector = (*result)->as(); @@ -93,6 +236,41 @@ class DecimalAverageAggregate : public DecimalAggregate { explicit DecimalAverageAggregate(TypePtr resultType, TypePtr sumType) : DecimalAggregate(resultType), sumType_(sumType) {} +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.inputCount != 1 || !context.inputType) { + return false; + } + if (context.isRawInput) { + return context.inputType->isDecimal() && + (context.inputType->isShortDecimal() || + context.inputType->isLongDecimal()); + } + return context.inputType->isRow() && context.inputType->size() == 2 && + context.inputType->childAt(0)->isDecimal() && + context.inputType->childAt(1)->kind() == TypeKind::BIGINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + const auto& valueType = + context.isRawInput ? context.inputType : context.inputType->childAt(0); + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Avg, + valueType->isShortDecimal() ? 
jit::HashAggrJitValueKind::Int64 + : jit::HashAggrJitValueKind::Int128, + jit::HashAggrJitValueKind::Int128, + false, + !context.isRawInput, + true, + hashAggrJitOps()}; + } +#endif + void addIntermediateResults( char** groups, const SelectivityVector& rows, @@ -333,6 +511,117 @@ class DecimalAverageAggregate : public DecimalAggregate { } private: +#ifdef ENABLE_BOLT_JIT + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateCall( + codegen.module().getFunction("jit_HashAggrInitDecimalAvg"), + {group, codegen.builder().getInt32(slot.offset)}); + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock* nextBlock) { + if (slot.mergeInput) { + auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge", + function, + continueBlock); + auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); + auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); + auto* count = codegen.loadDecodedRowField( + decoded, row, 1, jit::HashAggrJitValueKind::Int64); + auto* countPositive = codegen.builder().CreateICmpSGT( + count, codegen.builder().getInt64(0)); + auto* isOverflow = codegen.builder().CreateAnd( + sumIsNull, + codegen.builder().CreateAnd( + codegen.builder().CreateNot(countIsNull), countPositive)); + codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + 
codegen.builder().SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(mergeBlock); + auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrMergeDecimalAvgI128" + : "jit_HashAggrMergeDecimalAvgI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? codegen.castValue( + sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) + : sum, + count}); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(continueBlock); + return; + } + + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalAvgI128" + : "jit_HashAggrUpdateDecimalAvgI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? 
codegen.castValue( + rawValue, slot.inputKind, jit::HashAggrJitValueKind::Int128) + : rawValue}); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot&, + bool) { + return false; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen&, + llvm::Value*, + const jit::HashAggrJitSlot&, + const jit::HashAggrJitExtractTarget&) {} + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "avg_decimal", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } +#endif + template inline void mergeSumCount( LongDecimalWithOverflowState* accumulator, diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index d8d8d5e3d..d2dd73324 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -61,6 +61,149 @@ class DecimalSumAggregate : public exec::Aggregate { return alignof(DecimalSum); } +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.inputCount != 1 || !context.inputType) { + return false; + } + if (context.isRawInput) { + return context.inputType->isDecimal() && + (context.inputType->isShortDecimal() || + context.inputType->isLongDecimal()); + } + return context.inputType->isRow() && context.inputType->size() == 2 && + context.inputType->childAt(0)->isDecimal() && + context.inputType->childAt(1)->kind() == TypeKind::BOOLEAN; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + const auto& valueType = + context.isRawInput ? 
context.inputType : context.inputType->childAt(0); + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Sum, + valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 + : jit::HashAggrJitValueKind::Int128, + jit::HashAggrJitValueKind::Int128, + false, + !context.isRawInput, + true, + hashAggrJitOps()}; + } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateCall( + codegen.module().getFunction("jit_HashAggrInitDecimalSum"), + {group, codegen.builder().getInt32(slot.offset)}); + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock* nextBlock) { + if (slot.mergeInput) { + auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge", + function, + continueBlock); + auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); + auto* isEmpty = codegen.loadDecodedRowField( + decoded, row, 1, jit::HashAggrJitValueKind::Int8); + auto* isNotEmpty = + codegen.builder().CreateICmpEQ(isEmpty, codegen.builder().getInt8(0)); + auto* isOverflow = codegen.builder().CreateAnd(sumIsNull, isNotEmpty); + codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + codegen.builder().SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(mergeBlock); + auto* 
sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrMergeDecimalSumI128" + : "jit_HashAggrMergeDecimalSumI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? codegen.castValue( + sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) + : sum, + isEmpty}); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(continueBlock); + return; + } + + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalSumI128" + : "jit_HashAggrUpdateDecimalSumI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? 
codegen.castValue( + rawValue, slot.inputKind, jit::HashAggrJitValueKind::Int128) + : rawValue}); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot&, + bool) { + return false; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen&, + llvm::Value*, + const jit::HashAggrJitSlot&, + const jit::HashAggrJitExtractTarget&) {} + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "sum_decimal", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } + + public: +#endif + void initializeNewGroups( char** groups, folly::Range indices) override { diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index cd852df3f..622cb0e6e 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -37,7 +37,112 @@ namespace bytedance::bolt::functions::aggregate::sparksql { namespace { template -using SumAggregate = SumAggregateBase; +class SumAggregate : public SumAggregateBase { + public: + explicit SumAggregate(TypePtr resultType) + : SumAggregateBase(resultType) {} + +#ifdef ENABLE_BOLT_JIT + bool supportsHashAggrJit( + const jit::HashAggrJitPlanContext& context) const override { + if (context.inputCount != 1 || !context.inputType) { + return false; + } + if (context.inputType->isRow() || context.inputType->isDecimal()) { + return false; + } + return jit::isHashAggrJitSupportedType(context.inputType->kind()) || + context.inputType->kind() == TypeKind::HUGEINT; + } + + std::optional createHashAggrJitDescriptor( + const jit::HashAggrJitPlanContext& context) const override { + if (!supportsHashAggrJit(context)) { + return std::nullopt; + } + + auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + if (!inputKind.has_value()) { + return std::nullopt; + } + + const auto 
accumulatorKind = + (*inputKind == jit::HashAggrJitValueKind::Float || + *inputKind == jit::HashAggrJitValueKind::Double) + ? jit::HashAggrJitValueKind::Double + : jit::HashAggrJitValueKind::Int64; + + return jit::HashAggrJitDescriptor{ + jit::HashAggrJitKind::Sum, + *inputKind, + accumulatorKind, + false, + !context.isRawInput, + false, + hashAggrJitOps()}; + } + + private: + static void compileHashAggrJitCreate( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + auto* accType = codegen.llvmType(slot.accumulatorKind); + if (codegen.isFloatKind(slot.accumulatorKind)) { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); + } else { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantInt::get(accType, 0)); + } + } + + static void compileHashAggrJitAdd( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* value = + codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); + auto* accType = codegen.llvmType(slot.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldValue = codegen.loadValue(group, accType, slot.offset); + auto* newValue = codegen.isFloatKind(slot.accumulatorKind) + ? 
codegen.builder().CreateFAdd(oldValue, value) + : codegen.builder().CreateAdd(oldValue, value); + codegen.storeValue(group, accType, slot.offset, newValue); + } + + static bool canCompileHashAggrJitExtract( + const jit::HashAggrJitSlot&, + bool) { + return false; + } + + static void compileHashAggrJitExtract( + jit::HashAggrJitCodegen&, + llvm::Value*, + const jit::HashAggrJitSlot&, + const jit::HashAggrJitExtractTarget&) {} + + static const jit::HashAggrJitOps* hashAggrJitOps() { + static const jit::HashAggrJitOps kOps{ + "sum", + &compileHashAggrJitCreate, + &compileHashAggrJitAdd, + &canCompileHashAggrJitExtract, + &compileHashAggrJitExtract}; + return &kOps; + } + + public: +#endif +}; TypePtr getDecimalSumType( const TypePtr& resultType, diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index ae0c9d968..12a68ed67 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -13,6 +13,7 @@ #include +#include "bolt/common/base/Exceptions.h" #include "bolt/jit/ThrustJITv2.h" extern "C" { @@ -123,6 +124,48 @@ __attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalAvgI128 ++state->count; } +__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalSumI64( + char* group, + int32_t offset, + int64_t value, + int8_t isEmpty) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow( + state->sum, static_cast(value), state->sum); + state->isEmpty = state->isEmpty && static_cast(isEmpty); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalSumI128( + char* group, + int32_t offset, + bytedance::bolt::int128_t value, + int8_t isEmpty) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); + state->isEmpty = state->isEmpty && static_cast(isEmpty); +} + +__attribute__((__visibility__("default"))) void 
jit_HashAggrMergeDecimalAvgI64( + char* group, + int32_t offset, + int64_t value, + int64_t count) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow( + state->sum, static_cast(value), state->sum); + state->count += count; +} + +__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalAvgI128( + char* group, + int32_t offset, + bytedance::bolt::int128_t value, + int64_t count) { + auto* state = reinterpret_cast(group + offset); + state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); + state->count += count; +} + } // extern "C" namespace bytedance::bolt::jit { @@ -160,8 +203,14 @@ void ensureBuiltinDeclarations(llvm::Module& module) { module, "jit_GetDecodedValueDouble", doubleTy, {i8PtrTy, i32Ty}); declareFunction( module, "jit_GetDecodedRowFieldDouble", doubleTy, {i8PtrTy, i32Ty, i32Ty}); + declareFunction( + module, "jit_GetDecodedRowFieldI8", i8Ty, {i8PtrTy, i32Ty, i32Ty}); declareFunction( module, "jit_GetDecodedRowFieldI64", i64Ty, {i8PtrTy, i32Ty, i32Ty}); + declareFunction( + module, "jit_GetDecodedRowFieldI128", i128Ty, {i8PtrTy, i32Ty, i32Ty}); + declareFunction( + module, "jit_GetDecodedRowFieldIsNull", i8Ty, {i8PtrTy, i32Ty, i32Ty}); declareFunction(module, "jit_GetDecodedIsNull", i8Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_HashAggrInitDecimalSum", voidTy, {i8PtrTy, i32Ty}); declareFunction(module, "jit_HashAggrInitDecimalAvg", voidTy, {i8PtrTy, i32Ty}); @@ -173,6 +222,26 @@ void ensureBuiltinDeclarations(llvm::Module& module) { module, "jit_HashAggrUpdateDecimalAvgI64", voidTy, {i8PtrTy, i32Ty, i64Ty}); declareFunction( module, "jit_HashAggrUpdateDecimalAvgI128", voidTy, {i8PtrTy, i32Ty, i128Ty}); + declareFunction( + module, + "jit_HashAggrMergeDecimalSumI64", + voidTy, + {i8PtrTy, i32Ty, i64Ty, i8Ty}); + declareFunction( + module, + "jit_HashAggrMergeDecimalSumI128", + voidTy, + {i8PtrTy, i32Ty, i128Ty, i8Ty}); + declareFunction( + module, + 
"jit_HashAggrMergeDecimalAvgI64", + voidTy, + {i8PtrTy, i32Ty, i64Ty, i64Ty}); + declareFunction( + module, + "jit_HashAggrMergeDecimalAvgI128", + voidTy, + {i8PtrTy, i32Ty, i128Ty, i64Ty}); declareFunction(module, "jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); declareFunction(module, "jit_HashAggrSetFlatI8", voidTy, {i8PtrTy, i32Ty, i8Ty, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatI16", voidTy, {i8PtrTy, i32Ty, i16Ty, i8Ty}); @@ -227,6 +296,24 @@ std::string decodedValueFunction(HashAggrJitValueKind kind) { return "jit_GetDecodedValueI64"; } +std::string decodedRowFieldFunction(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Int8: + return "jit_GetDecodedRowFieldI8"; + case HashAggrJitValueKind::Int64: + return "jit_GetDecodedRowFieldI64"; + case HashAggrJitValueKind::Int128: + return "jit_GetDecodedRowFieldI128"; + case HashAggrJitValueKind::Double: + return "jit_GetDecodedRowFieldDouble"; + case HashAggrJitValueKind::Int16: + case HashAggrJitValueKind::Int32: + case HashAggrJitValueKind::Float: + break; + } + return ""; +} + std::string setFlatValueFunction(HashAggrJitValueKind kind); bool isFloatKind(HashAggrJitValueKind kind) { @@ -392,17 +479,29 @@ bool HashAggrJitCodegen::isFloatKind(HashAggrJitValueKind kind) const { return ::bytedance::bolt::jit::isFloatKind(kind); } -llvm::Value* HashAggrJitCodegen::loadAvgMergeField( +llvm::Value* HashAggrJitCodegen::loadDecodedRowField( llvm::Value* decoded, llvm::Value* row, int32_t field, - llvm::Type* type) const { - const char* name = type->isDoubleTy() ? 
"jit_GetDecodedRowFieldDouble" - : "jit_GetDecodedRowFieldI64"; + HashAggrJitValueKind kind) const { + const auto name = decodedRowFieldFunction(kind); + BOLT_CHECK( + !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); return builder().CreateCall( module_.getFunction(name), {decoded, row, builder().getInt32(field)}); } +llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( + llvm::Value* decoded, + llvm::Value* row, + int32_t field) const { + return builder().CreateICmpNE( + builder().CreateCall( + module_.getFunction("jit_GetDecodedRowFieldIsNull"), + {decoded, row, builder().getInt32(field)}), + builder().getInt8(0)); +} + void HashAggrJitCodegen::emitFlatValue( llvm::Value* vector, llvm::Value* row, diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index d964e3319..7f096c739 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -147,11 +147,15 @@ class HashAggrJitCodegen { HashAggrJitValueKind from, HashAggrJitValueKind to) const; bool isFloatKind(HashAggrJitValueKind kind) const; - llvm::Value* loadAvgMergeField( + llvm::Value* loadDecodedRowField( llvm::Value* decoded, llvm::Value* row, int32_t field, - llvm::Type* type) const; + HashAggrJitValueKind kind) const; + llvm::Value* isDecodedRowFieldNull( + llvm::Value* decoded, + llvm::Value* row, + int32_t field) const; void emitFlatValue( llvm::Value* vector, llvm::Value* row, From c5586660dc1b754361f16cf0903854abba41dccb Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 6 Jun 2026 21:02:09 +0800 Subject: [PATCH 16/98] remove config --- bolt/core/QueryConfig.h | 6 --- bolt/exec/GroupingSet.cpp | 11 +++--- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 1 - .../aggregates/tests/CountAggregationTest.cpp | 38 +++++++++++++++++++ .../prestosql/aggregates/tests/MinMaxTest.cpp | 26 +++++++++++++ .../aggregates/tests/SumAggregationTest.cpp | 4 -- bolt/jit/aggregation/HashAggrJit.cpp | 2 + 7 files changed, 71 
insertions(+), 17 deletions(-) diff --git a/bolt/core/QueryConfig.h b/bolt/core/QueryConfig.h index de88fd923..1d96c5ebf 100644 --- a/bolt/core/QueryConfig.h +++ b/bolt/core/QueryConfig.h @@ -657,8 +657,6 @@ class QueryConfig { "hashaggr.jit.min_fuse_width"; static constexpr const char* kHashAggrJitMaxFuseWidth = "hashaggr.jit.max_fuse_width"; - static constexpr const char* kHashAggrJitCompileMinCount = - "hashaggr.jit.compile_min_count"; // expired, to deleted later static constexpr const char* kBoltJitEnabled = "bolt.jit.enabled"; @@ -1626,10 +1624,6 @@ class QueryConfig { return get(kHashAggrJitMaxFuseWidth, 16); } - int32_t hashAggrJitCompileMinCount() const { - return get(kHashAggrJitCompileMinCount, 3); - } - int exceptionTraceLevel() const { return get(kExceptionTraceLevel, 1); } diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index b11262d42..2f29db9d3 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -879,17 +879,16 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { return; } - const auto minFuseWidth = std::max(1, queryConfig_.hashAggrJitMinFuseWidth()); - const auto maxFuseWidth = std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); - const auto compileMinCount = - std::max(1, queryConfig_.hashAggrJitCompileMinCount()); - const auto minChunkWidth = std::max(minFuseWidth, compileMinCount); + const auto minFuseWidth = + std::max(1, queryConfig_.hashAggrJitMinFuseWidth()); + const auto maxFuseWidth = + std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); + const auto minChunkWidth = minFuseWidth; LOG(INFO) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ << " isPartial=" << isPartial_ << " aggregates=" << aggregates_.size() << " minFuseWidth=" << minFuseWidth << " maxFuseWidth=" << maxFuseWidth - << " compileMinCount=" << compileMinCount << " minChunkWidth=" << minChunkWidth; std::vector currentChunkSlots; currentChunkSlots.reserve(maxFuseWidth); diff --git 
a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index f08032696..2e9a98cf3 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -229,7 +229,6 @@ class HashAggrJitBenchmark : public VectorTestBase { .config(core::QueryConfig::kHashAggrJitEnabled, enableJit ? "true" : "false") .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "4") .config(core::QueryConfig::kHashAggrJitMaxFuseWidth, "16") - .config(core::QueryConfig::kHashAggrJitCompileMinCount, "3") .copyResults(pool_.get()); } diff --git a/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp b/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp index 78da1704e..bee6cd1a4 100644 --- a/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp @@ -94,6 +94,44 @@ TEST_F(CountAggregationTest, count) { "SELECT c0 % 10, count(c7) FROM tmp GROUP BY 1"); } +TEST_F(CountAggregationTest, hashAggrJitBooleanCount) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 11; }), + makeFlatVector( + 256, + [](auto row) { return row % 3 == 0; }, + [](auto row) { return row % 7 == 0; })}); + + auto singlePlan = PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"count(c1)"}) + .planNode(); + + auto singleNoJit = AssertQueryBuilder(singlePlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto singleJit = AssertQueryBuilder(singlePlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({singleNoJit}, {singleJit}); + + auto partialFinalPlan = PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"count(c1)"}) + .finalAggregation() + .planNode(); + + auto partialFinalNoJit = 
AssertQueryBuilder(partialFinalPlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto partialFinalJit = AssertQueryBuilder(partialFinalPlan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({partialFinalNoJit}, {partialFinalJit}); +} + TEST_F(CountAggregationTest, mask) { std::vector data; // Make batches where some batches have mask all true, some half and half and diff --git a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp index cbc68fe16..b6653b53b 100644 --- a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp @@ -30,6 +30,8 @@ #include #include +#include "bolt/exec/tests/utils/AssertQueryBuilder.h" +#include "bolt/exec/tests/utils/PlanBuilder.h" #include "bolt/common/base/tests/GTestUtils.h" #include "bolt/functions/lib/aggregates/tests/utils/AggregationTestBase.h" #include "bolt/vector/fuzzer/VectorFuzzer.h" @@ -220,6 +222,30 @@ TEST_F(MinMaxTest, minBoolean) { doTest(min, BOOLEAN()); } +TEST_F(MinMaxTest, hashAggrJitBooleanMinMax) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 13; }), + makeFlatVector( + 256, + [](auto row) { return row % 5 < 2; }, + [](auto row) { return row % 11 == 0; })}); + + auto plan = PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, 
{jit}); +} + TEST_F(MinMaxTest, constVarchar) { // Create two batches of the source data for the aggregation: // Column c0 with 1K of "apple" and 1K of "banana". diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index a29da74f5..5d89a6cab 100644 --- a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -161,7 +161,6 @@ TEST_F(SumAggregationTest, hashAggrJitDecimalSumAndFloatingMinMax) { auto jit = AssertQueryBuilder(plan) .config(core::QueryConfig::kHashAggrJitEnabled, "true") .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") - .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") .copyResults(pool()); assertEqualResults({noJit}, {jit}); } @@ -190,7 +189,6 @@ TEST_F(SumAggregationTest, hashAggrJitMergeAndExtract) { auto jit = AssertQueryBuilder(plan) .config(core::QueryConfig::kHashAggrJitEnabled, "true") .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") - .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") .copyResults(pool()); assertEqualResults({noJit}, {jit}); } @@ -213,7 +211,6 @@ TEST_F(SumAggregationTest, hashAggrJitPartialAvgExtractAccumulators) { auto jit = AssertQueryBuilder(plan) .config(core::QueryConfig::kHashAggrJitEnabled, "true") .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") - .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") .copyResults(pool()); assertEqualResults({noJit}, {jit}); } @@ -255,7 +252,6 @@ TEST_F(SumAggregationTest, hashAggrJitSplitsContiguousSegments) { auto jit = AssertQueryBuilder(plan) .config(core::QueryConfig::kHashAggrJitEnabled, "true") .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") - .config(core::QueryConfig::kHashAggrJitCompileMinCount, "1") .config(core::QueryConfig::kHashAggrJitMaxFuseWidth, "4") .copyResults(pool()); assertEqualResults({noJit}, {jit}); diff --git 
a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 12a68ed67..fb4b675e3 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -780,6 +780,7 @@ std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { std::optional hashAggrJitValueKind(TypeKind kind) { switch (kind) { + case TypeKind::BOOLEAN: case TypeKind::TINYINT: return HashAggrJitValueKind::Int8; case TypeKind::SMALLINT: @@ -801,6 +802,7 @@ std::optional hashAggrJitValueKind(TypeKind kind) { bool isHashAggrJitSupportedType(TypeKind kind) { switch (kind) { + case TypeKind::BOOLEAN: case TypeKind::TINYINT: case TypeKind::SMALLINT: case TypeKind::INTEGER: From b958315d082e3f765ef43ee2cdd7ad156d38452c Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 6 Jun 2026 22:28:28 +0800 Subject: [PATCH 17/98] support bool input for max/min/count --- .../prestosql/aggregates/tests/MinMaxTest.cpp | 37 ++++++++++++------- bolt/jit/aggregation/HashAggrJit.cpp | 11 ++++++ bolt/jit/aggregation/HashAggrJit.h | 1 + 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp index b6653b53b..67b20660e 100644 --- a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp @@ -230,20 +230,31 @@ TEST_F(MinMaxTest, hashAggrJitBooleanMinMax) { [](auto row) { return row % 5 < 2; }, [](auto row) { return row % 11 == 0; })}); - auto plan = PlanBuilder() - .values({data}) - .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) - .finalAggregation() - .planNode(); - - auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) - .config(core::QueryConfig::kHashAggrJitEnabled, "false") + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .planNode(); + }, + [&]() { + return 
PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") .copyResults(pool()); - auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) - .config(core::QueryConfig::kHashAggrJitEnabled, "true") - .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") - .copyResults(pool()); - assertEqualResults({noJit}, {jit}); + assertEqualResults({noJit}, {jit}); + } } TEST_F(MinMaxTest, constVarchar) { diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index fb4b675e3..a16f8b852 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -192,6 +192,7 @@ void ensureBuiltinDeclarations(llvm::Module& module) { auto* voidTy = llvm::Type::getVoidTy(context); auto* i8PtrTy = llvm::PointerType::get(context, 0); + declareFunction(module, "jit_GetDecodedValueBool", i8Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_GetDecodedValueI8", i8Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_GetDecodedValueI16", i16Ty, {i8PtrTy, i32Ty}); declareFunction(module, "jit_GetDecodedValueI32", i32Ty, {i8PtrTy, i32Ty}); @@ -258,6 +259,7 @@ void ensureBuiltinDeclarations(llvm::Module& module) { llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { switch (kind) { + case HashAggrJitValueKind::Bool: case HashAggrJitValueKind::Int8: return builder.getInt8Ty(); case HashAggrJitValueKind::Int16: @@ -278,6 +280,8 @@ llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { std::string decodedValueFunction(HashAggrJitValueKind kind) { switch (kind) { + case 
HashAggrJitValueKind::Bool: + return "jit_GetDecodedValueBool"; case HashAggrJitValueKind::Int8: return "jit_GetDecodedValueI8"; case HashAggrJitValueKind::Int16: @@ -298,6 +302,7 @@ std::string decodedValueFunction(HashAggrJitValueKind kind) { std::string decodedRowFieldFunction(HashAggrJitValueKind kind) { switch (kind) { + case HashAggrJitValueKind::Bool: case HashAggrJitValueKind::Int8: return "jit_GetDecodedRowFieldI8"; case HashAggrJitValueKind::Int64: @@ -676,6 +681,9 @@ std::string setFlatValueFunction(HashAggrJitValueKind kind) { return "jit_HashAggrSetFlatFloat"; case HashAggrJitValueKind::Double: return "jit_HashAggrSetFlatDouble"; + // Bool output vectors are FlatVector, which cannot reuse the int8 + // setter. JIT extract is not yet supported for Bool. + case HashAggrJitValueKind::Bool: case HashAggrJitValueKind::Int128: return ""; } @@ -760,6 +768,8 @@ HashAggrJitChunk::HashAggrJitChunk( std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { switch (kind) { + case HashAggrJitValueKind::Bool: + return "bool"; case HashAggrJitValueKind::Int8: return "i8"; case HashAggrJitValueKind::Int16: @@ -781,6 +791,7 @@ std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { std::optional hashAggrJitValueKind(TypeKind kind) { switch (kind) { case TypeKind::BOOLEAN: + return HashAggrJitValueKind::Bool; case TypeKind::TINYINT: return HashAggrJitValueKind::Int8; case TypeKind::SMALLINT: diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 7f096c739..c7875df09 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -41,6 +41,7 @@ enum class HashAggrJitKind : uint8_t { }; enum class HashAggrJitValueKind : uint8_t { + Bool, Int8, Int16, Int32, From 674df4151a9a74374f7a951928f338d8338a965c Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 6 Jun 2026 22:43:31 +0800 Subject: [PATCH 18/98] support decimal input for max/min/count --- .../prestosql/aggregates/CountAggregate.cpp 
| 5 +- .../prestosql/aggregates/MinMaxAggregates.cpp | 5 +- .../aggregates/tests/CountAggregationTest.cpp | 45 +++++++++++ .../prestosql/aggregates/tests/MinMaxTest.cpp | 74 +++++++++++++++++++ 4 files changed, 125 insertions(+), 4 deletions(-) diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index d253e44eb..64eea87d7 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -59,8 +59,9 @@ class CountAggregate : public SimpleNumericAggregate { if (context.isRawInput) { return context.inputCount == 0 || (context.inputCount == 1 && context.inputType != nullptr && - !context.inputType->isRow() && !context.inputType->isDecimal() && - jit::isHashAggrJitSupportedType(context.inputType->kind())); + !context.inputType->isRow() && + (context.inputType->isDecimal() || + jit::isHashAggrJitSupportedType(context.inputType->kind()))); } return context.inputCount == 1 && context.inputType != nullptr && context.inputType->kind() == TypeKind::BIGINT; diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 9e1fa7692..387ebccd7 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -56,8 +56,9 @@ class MinMaxAggregate : public SimpleNumericAggregate { bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { return context.inputCount == 1 && context.inputType != nullptr && - !context.inputType->isRow() && !context.inputType->isDecimal() && - (jit::isHashAggrJitSupportedType(context.inputType->kind()) || + !context.inputType->isRow() && + (context.inputType->isDecimal() || + jit::isHashAggrJitSupportedType(context.inputType->kind()) || context.inputType->kind() == TypeKind::HUGEINT); } diff --git 
a/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp b/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp index bee6cd1a4..dc3ec7fa8 100644 --- a/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/CountAggregationTest.cpp @@ -132,6 +132,51 @@ TEST_F(CountAggregationTest, hashAggrJitBooleanCount) { assertEqualResults({partialFinalNoJit}, {partialFinalJit}); } +TEST_F(CountAggregationTest, hashAggrJitDecimalCount) { + auto shortDecimal = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 11; }), + makeFlatVector( + 256, + [](auto row) { return (row * 7) % 101; }, + [](auto row) { return row % 7 == 0; }, + DECIMAL(18, 3))}); + auto longDecimal = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 11; }), + makeFlatVector( + 256, + [](auto row) { return HugeInt::build(row % 9, (row * 31) % 997); }, + [](auto row) { return row % 7 == 0; }, + DECIMAL(38, 19))}); + + for (const auto& data : {shortDecimal, longDecimal}) { + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"count(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"count(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } + } +} + TEST_F(CountAggregationTest, mask) { std::vector data; // Make batches where some batches have mask all true, some half and half and diff --git a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp 
b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp index 67b20660e..652ad1f9e 100644 --- a/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/MinMaxTest.cpp @@ -257,6 +257,80 @@ TEST_F(MinMaxTest, hashAggrJitBooleanMinMax) { } } +TEST_F(MinMaxTest, hashAggrJitShortDecimalMinMax) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 13; }), + makeFlatVector( + 256, + [](auto row) { return (row * 7) % 101; }, + [](auto row) { return row % 11 == 0; }, + DECIMAL(18, 3))}); + + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } +} + +TEST_F(MinMaxTest, hashAggrJitLongDecimalMinMax) { + auto data = makeRowVector( + {makeFlatVector(256, [](auto row) { return row % 13; }), + makeFlatVector( + 256, + [](auto row) { + return HugeInt::build(row % 9, (row * 31) % 997); + }, + [](auto row) { return row % 11 == 0; }, + DECIMAL(38, 19))}); + + for (const auto& makePlan : + std::vector>{ + [&]() { + return PlanBuilder() + .values({data}) + .singleAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .planNode(); + }, + [&]() { + return PlanBuilder() + .values({data}) + .partialAggregation({"c0"}, {"min(c1)", "max(c1)"}) + .finalAggregation() + .planNode(); + }}) { + auto plan = makePlan(); + auto noJit = AssertQueryBuilder(plan, 
duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); + } +} + TEST_F(MinMaxTest, constVarchar) { // Create two batches of the source data for the aggregation: // Column c0 with 1K of "apple" and 1K of "banana". From ac4026cba6b00b28ddc5454d4782685d7ba3b09f Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 6 Jun 2026 23:14:34 +0800 Subject: [PATCH 19/98] split addrawinput and addintermediateresults for hash aggr jit --- .../prestosql/aggregates/CountAggregate.cpp | 54 +++-- .../prestosql/aggregates/MinMaxAggregates.cpp | 35 +++- .../sparksql/aggregates/AverageAggregate.cpp | 186 ++++++++++-------- .../sparksql/aggregates/DecimalSumAggregate.h | 117 +++++------ .../sparksql/aggregates/SumAggregate.cpp | 35 +++- bolt/jit/aggregation/HashAggrJit.cpp | 13 +- bolt/jit/aggregation/HashAggrJit.h | 5 +- 7 files changed, 272 insertions(+), 173 deletions(-) diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 64eea87d7..367259e67 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -91,7 +91,7 @@ class CountAggregate : public SimpleNumericAggregate { } private: - static void compileHashAggrJitCreate( + static void compileHashAggrJitInitGroup( jit::HashAggrJitCodegen& codegen, llvm::Value* group, const jit::HashAggrJitSlot& slot) { @@ -102,24 +102,13 @@ class CountAggregate : public SimpleNumericAggregate { codegen.builder().getInt64(0)); } - static void compileHashAggrJitAdd( + static void addInc( jit::HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, const jit::HashAggrJitSlot& slot, - bool, 
- llvm::BasicBlock*) { - auto* state = codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); - llvm::Value* inc = nullptr; - if (slot.countStar || !slot.mergeInput) { - inc = codegen.builder().getInt64(1); - } else { - inc = codegen.castValue( - codegen.loadDecodedValue(decoded, row, slot), - slot.inputKind, - jit::HashAggrJitValueKind::Int64); - } + llvm::Value* inc) { + auto* state = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); codegen.storeValue( group, codegen.builder().getInt64Ty(), @@ -127,6 +116,34 @@ class CountAggregate : public SimpleNumericAggregate { codegen.builder().CreateAdd(state, inc)); } + static void compileHashAggrJitAddRawInput( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* /*decoded*/, + llvm::Value* /*row*/, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + addInc(codegen, group, slot, codegen.builder().getInt64(1)); + } + + static void compileHashAggrJitAddIntermediateResults( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + llvm::Value* inc = slot.countStar + ? 
codegen.builder().getInt64(1) + : codegen.castValue( + codegen.loadDecodedValue(decoded, row, slot), + slot.inputKind, + jit::HashAggrJitValueKind::Int64); + addInc(codegen, group, slot, inc); + } + static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { @@ -150,8 +167,9 @@ class CountAggregate : public SimpleNumericAggregate { static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ "count", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, + &compileHashAggrJitInitGroup, + &compileHashAggrJitAddRawInput, + &compileHashAggrJitAddIntermediateResults, &canCompileHashAggrJitExtract, &compileHashAggrJitExtract}; return &kOps; diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 387ebccd7..25a61b845 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -82,7 +82,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { } private: - static void compileHashAggrJitCreate( + static void compileHashAggrJitInitGroup( jit::HashAggrJitCodegen& codegen, llvm::Value* group, const jit::HashAggrJitSlot& slot) { @@ -95,7 +95,9 @@ class MinMaxAggregate : public SimpleNumericAggregate { } } - static void compileHashAggrJitAdd( + // min/max use the same logic for raw input and intermediate merge: both pick + // the better (min/max) of the decoded value and the current accumulator. 
+ static void compileHashAggrJitUpdate( jit::HashAggrJitCodegen& codegen, llvm::Value* group, llvm::Value* decoded, @@ -140,6 +142,30 @@ class MinMaxAggregate : public SimpleNumericAggregate { codegen.clearAccumulatorNull(group, slot); } + static void compileHashAggrJitAddRawInput( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool checkInputNulls, + llvm::BasicBlock* nextBlock) { + compileHashAggrJitUpdate( + codegen, group, decoded, row, slot, checkInputNulls, nextBlock); + } + + static void compileHashAggrJitAddIntermediateResults( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool checkInputNulls, + llvm::BasicBlock* nextBlock) { + compileHashAggrJitUpdate( + codegen, group, decoded, row, slot, checkInputNulls, nextBlock); + } + static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot& slot, bool) { @@ -162,8 +188,9 @@ class MinMaxAggregate : public SimpleNumericAggregate { static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ "minmax", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, + &compileHashAggrJitInitGroup, + &compileHashAggrJitAddRawInput, + &compileHashAggrJitAddIntermediateResults, &canCompileHashAggrJitExtract, &compileHashAggrJitExtract}; return &kOps; diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 662b73ffe..d153f67fa 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -92,7 +92,7 @@ class AverageAggregate } private: - static void compileHashAggrJitCreate( + static void compileHashAggrJitInitGroup( jit::HashAggrJitCodegen& codegen, llvm::Value* group, const jit::HashAggrJitSlot& slot) { @@ -109,7 +109,7 @@ class AverageAggregate 
codegen.builder().getInt64(0)); } - static void compileHashAggrJitAdd( + static void compileHashAggrJitAddRawInput( jit::HashAggrJitCodegen& codegen, llvm::Value* group, llvm::Value* decoded, @@ -117,29 +117,6 @@ class AverageAggregate const jit::HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { - if (slot.mergeInput) { - codegen.clearAccumulatorNull(group, slot); - auto* sum = codegen.loadDecodedRowField( - decoded, row, 0, jit::HashAggrJitValueKind::Double); - auto* count = codegen.loadDecodedRowField( - decoded, row, 1, jit::HashAggrJitValueKind::Int64); - auto* oldSum = - codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); - codegen.storeValue( - group, - codegen.builder().getDoubleTy(), - slot.offset, - codegen.builder().CreateFAdd(oldSum, sum)); - auto* oldCount = codegen.loadValue( - group, codegen.builder().getInt64Ty(), slot.offset + 8); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().CreateAdd(oldCount, count)); - return; - } - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); auto* value = codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); @@ -160,6 +137,35 @@ class AverageAggregate codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); } + static void compileHashAggrJitAddIntermediateResults( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + codegen.clearAccumulatorNull(group, slot); + auto* sum = codegen.loadDecodedRowField( + decoded, row, 0, jit::HashAggrJitValueKind::Double); + auto* count = codegen.loadDecodedRowField( + decoded, row, 1, jit::HashAggrJitValueKind::Int64); + auto* oldSum = + codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getDoubleTy(), + slot.offset, + codegen.builder().CreateFAdd(oldSum, sum)); + auto* oldCount = 
codegen.loadValue( + group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, count)); + } + static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { @@ -175,8 +181,9 @@ class AverageAggregate static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ "avg", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, + &compileHashAggrJitInitGroup, + &compileHashAggrJitAddRawInput, + &compileHashAggrJitAddIntermediateResults, &canCompileHashAggrJitExtract, &compileHashAggrJitExtract}; return &kOps; @@ -512,7 +519,7 @@ class DecimalAverageAggregate : public DecimalAggregate { private: #ifdef ENABLE_BOLT_JIT - static void compileHashAggrJitCreate( + static void compileHashAggrJitInitGroup( jit::HashAggrJitCodegen& codegen, llvm::Value* group, const jit::HashAggrJitSlot& slot) { @@ -522,68 +529,14 @@ class DecimalAverageAggregate : public DecimalAggregate { {group, codegen.builder().getInt32(slot.offset)}); } - static void compileHashAggrJitAdd( + static void compileHashAggrJitAddRawInput( jit::HashAggrJitCodegen& codegen, llvm::Value* group, llvm::Value* decoded, llvm::Value* row, const jit::HashAggrJitSlot& slot, bool, - llvm::BasicBlock* nextBlock) { - if (slot.mergeInput) { - auto* function = codegen.builder().GetInsertBlock()->getParent(); - auto* continueBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge_cont", - function, - nextBlock); - auto* overflowBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge_overflow", - function, - continueBlock); - auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge", - function, - continueBlock); - auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); - auto* 
count = codegen.loadDecodedRowField( - decoded, row, 1, jit::HashAggrJitValueKind::Int64); - auto* countPositive = codegen.builder().CreateICmpSGT( - count, codegen.builder().getInt64(0)); - auto* isOverflow = codegen.builder().CreateAnd( - sumIsNull, - codegen.builder().CreateAnd( - codegen.builder().CreateNot(countIsNull), countPositive)); - codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); - - codegen.builder().SetInsertPoint(overflowBlock); - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); - codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? "jit_HashAggrMergeDecimalAvgI128" - : "jit_HashAggrMergeDecimalAvgI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? 
codegen.castValue( - sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) - : sum, - count}); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(continueBlock); - return; - } - + llvm::BasicBlock*) { auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); codegen.clearAccumulatorNull(group, slot); const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 @@ -599,6 +552,66 @@ class DecimalAverageAggregate : public DecimalAggregate { : rawValue}); } + static void compileHashAggrJitAddIntermediateResults( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock* nextBlock) { + auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge", + function, + continueBlock); + auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); + auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); + auto* count = codegen.loadDecodedRowField( + decoded, row, 1, jit::HashAggrJitValueKind::Int64); + auto* countPositive = codegen.builder().CreateICmpSGT( + count, codegen.builder().getInt64(0)); + auto* isOverflow = codegen.builder().CreateAnd( + sumIsNull, + codegen.builder().CreateAnd( + codegen.builder().CreateNot(countIsNull), countPositive)); + codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + codegen.builder().SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(mergeBlock); + auto* sum = 
codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrMergeDecimalAvgI128" + : "jit_HashAggrMergeDecimalAvgI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? codegen.castValue( + sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) + : sum, + count}); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(continueBlock); + } + static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { @@ -614,8 +627,9 @@ class DecimalAverageAggregate : public DecimalAggregate { static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ "avg_decimal", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, + &compileHashAggrJitInitGroup, + &compileHashAggrJitAddRawInput, + &compileHashAggrJitAddIntermediateResults, &canCompileHashAggrJitExtract, &compileHashAggrJitExtract}; return &kOps; diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index d2dd73324..bc5209861 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -96,7 +96,7 @@ class DecimalSumAggregate : public exec::Aggregate { } private: - static void compileHashAggrJitCreate( + static void compileHashAggrJitInitGroup( jit::HashAggrJitCodegen& codegen, llvm::Value* group, const jit::HashAggrJitSlot& slot) { @@ -106,64 +106,14 @@ class DecimalSumAggregate : public exec::Aggregate { {group, codegen.builder().getInt32(slot.offset)}); } - static void compileHashAggrJitAdd( + static void compileHashAggrJitAddRawInput( jit::HashAggrJitCodegen& codegen, llvm::Value* group, llvm::Value* decoded, llvm::Value* row, const 
jit::HashAggrJitSlot& slot, bool, - llvm::BasicBlock* nextBlock) { - if (slot.mergeInput) { - auto* function = codegen.builder().GetInsertBlock()->getParent(); - auto* continueBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge_cont", - function, - nextBlock); - auto* overflowBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge_overflow", - function, - continueBlock); - auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge", - function, - continueBlock); - auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* isEmpty = codegen.loadDecodedRowField( - decoded, row, 1, jit::HashAggrJitValueKind::Int8); - auto* isNotEmpty = - codegen.builder().CreateICmpEQ(isEmpty, codegen.builder().getInt8(0)); - auto* isOverflow = codegen.builder().CreateAnd(sumIsNull, isNotEmpty); - codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); - - codegen.builder().SetInsertPoint(overflowBlock); - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); - codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? "jit_HashAggrMergeDecimalSumI128" - : "jit_HashAggrMergeDecimalSumI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 - ? 
codegen.castValue( - sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) - : sum, - isEmpty}); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(continueBlock); - return; - } - + llvm::BasicBlock*) { auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); codegen.clearAccumulatorNull(group, slot); const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 @@ -179,6 +129,62 @@ class DecimalSumAggregate : public exec::Aggregate { : rawValue}); } + static void compileHashAggrJitAddIntermediateResults( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool, + llvm::BasicBlock* nextBlock) { + auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge", + function, + continueBlock); + auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); + auto* isEmpty = codegen.loadDecodedRowField( + decoded, row, 1, jit::HashAggrJitValueKind::Int8); + auto* isNotEmpty = + codegen.builder().CreateICmpEQ(isEmpty, codegen.builder().getInt8(0)); + auto* isOverflow = codegen.builder().CreateAnd(sumIsNull, isNotEmpty); + codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + codegen.builder().SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(mergeBlock); + auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.inputKind == 
jit::HashAggrJitValueKind::Int128 + ? "jit_HashAggrMergeDecimalSumI128" + : "jit_HashAggrMergeDecimalSumI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.inputKind == jit::HashAggrJitValueKind::Int128 + ? codegen.castValue( + sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) + : sum, + isEmpty}); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(continueBlock); + } + static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { @@ -194,8 +200,9 @@ class DecimalSumAggregate : public exec::Aggregate { static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ "sum_decimal", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, + &compileHashAggrJitInitGroup, + &compileHashAggrJitAddRawInput, + &compileHashAggrJitAddIntermediateResults, &canCompileHashAggrJitExtract, &compileHashAggrJitExtract}; return &kOps; diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index 622cb0e6e..b67219551 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -83,7 +83,7 @@ class SumAggregate : public SumAggregateBase { } private: - static void compileHashAggrJitCreate( + static void compileHashAggrJitInitGroup( jit::HashAggrJitCodegen& codegen, llvm::Value* group, const jit::HashAggrJitSlot& slot) { @@ -98,7 +98,9 @@ class SumAggregate : public SumAggregateBase { } } - static void compileHashAggrJitAdd( + // sum uses the same logic for raw input and intermediate merge: add the + // decoded value into the running accumulator. 
+ static void compileHashAggrJitAccumulate( jit::HashAggrJitCodegen& codegen, llvm::Value* group, llvm::Value* decoded, @@ -118,6 +120,30 @@ class SumAggregate : public SumAggregateBase { codegen.storeValue(group, accType, slot.offset, newValue); } + static void compileHashAggrJitAddRawInput( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool checkInputNulls, + llvm::BasicBlock* nextBlock) { + compileHashAggrJitAccumulate( + codegen, group, decoded, row, slot, checkInputNulls, nextBlock); + } + + static void compileHashAggrJitAddIntermediateResults( + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const jit::HashAggrJitSlot& slot, + bool checkInputNulls, + llvm::BasicBlock* nextBlock) { + compileHashAggrJitAccumulate( + codegen, group, decoded, row, slot, checkInputNulls, nextBlock); + } + static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { @@ -133,8 +159,9 @@ class SumAggregate : public SumAggregateBase { static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ "sum", - &compileHashAggrJitCreate, - &compileHashAggrJitAdd, + &compileHashAggrJitInitGroup, + &compileHashAggrJitAddRawInput, + &compileHashAggrJitAddIntermediateResults, &canCompileHashAggrJitExtract, &compileHashAggrJitExtract}; return &kOps; diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index a16f8b852..a852b6cbd 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -582,10 +582,10 @@ bool genInitIR( auto* group = builder.CreateLoad(i8PtrTy, groupAddr); for (const auto& slot : slots) { - if (slot.ops == nullptr || slot.ops->create == nullptr) { + if (slot.ops == nullptr || slot.ops->initGroup == nullptr) { return true; } - slot.ops->create(codegen, group, slot); + slot.ops->initGroup(codegen, group, slot); } 
auto* next = builder.CreateAdd(index, builder.getInt32(1)); @@ -649,10 +649,15 @@ bool genAddDenseIR( } builder.SetInsertPoint(updateBlock); - if (slot.ops == nullptr || slot.ops->add == nullptr) { + if (slot.ops == nullptr) { return true; } - slot.ops->add(codegen, group, decoded, row, slot, checkInputNulls, nextBlock); + auto* addFn = + slot.mergeInput ? slot.ops->addIntermediateResults : slot.ops->addRawInput; + if (addFn == nullptr) { + return true; + } + addFn(codegen, group, decoded, row, slot, checkInputNulls, nextBlock); builder.CreateBr(nextBlock); builder.SetInsertPoint(nextBlock); } diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index c7875df09..78960f645 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -82,8 +82,9 @@ struct HashAggrJitOps { const HashAggrJitExtractTarget&); const char* id; - CreateFn create; - AddFn add; + CreateFn initGroup; + AddFn addRawInput; + AddFn addIntermediateResults; CanExtractFn canExtract; ExtractFn extract; }; From 8017cc51fce3390ef5f503b88298ffd40d6520cd Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 7 Jun 2026 08:28:22 +0800 Subject: [PATCH 20/98] fix bug where non-JIT extractxxx runs while numNulls_ is wrongly left at zero --- bolt/exec/Aggregate.cpp | 1 + bolt/exec/Aggregate.h | 8 ++++++ bolt/exec/GroupingSet.cpp | 10 +++++++ .../prestosql/aggregates/MinMaxAggregates.cpp | 1 + .../sparksql/aggregates/AverageAggregate.cpp | 3 +++ .../sparksql/aggregates/DecimalSumAggregate.h | 1 + .../sparksql/aggregates/SumAggregate.cpp | 1 + .../aggregates/tests/SumAggregationTest.cpp | 27 +++++++++++++++++++ bolt/jit/aggregation/HashAggrJit.h | 5 ++++ 9 files changed, 57 insertions(+) diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index 7ca9137d2..853762139 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -346,6 +346,7 @@ jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( .countStar =
descriptor.countStar, .mergeInput = descriptor.mergeInput, .decimal = descriptor.decimal, + .initSetsNull = descriptor.initSetsNull, .ops = descriptor.ops}; } #endif diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index 301709f4c..34b5a1903 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -125,6 +125,14 @@ class Aggregate { jit::HashAggrJitSlot createHashAggrJitSlot( int32_t aggregateIndex, const jit::HashAggrJitDescriptor& descriptor) const; + + // HashAggr JIT initGroup marks accumulators as null by writing the null bit + // directly, bypassing setAllNulls/setNull. Since non-JIT extract relies on + // numNulls_ (see isNull()), GroupingSet must keep it in sync after running + // the JIT init path for the corresponding number of new groups. + void addNumNulls(uint64_t count) { + numNulls_ += count; + } #endif void setAllocator(HashStringAllocator* allocator) { diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 2f29db9d3..5cbaf1bad 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1050,6 +1050,16 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitNewGroups_[i] = groups[newGroups[i]]; } chunk.init(hashAggrJitNewGroups_.data(), newGroups.size()); + // JIT initGroup writes the null bit directly without touching + // Aggregate::numNulls_. Non-JIT extract relies on numNulls_ (isNull() + // short-circuits when it is 0), so keep it in sync here, mirroring the + // non-JIT initializeNewGroups/setAllNulls path. 
+ for (const auto& slot : chunk.slots()) { + if (slot.initSetsNull) { + aggregates_[slot.aggregateIndex].function->addNumNulls( + newGroups.size()); + } + } LOG(INFO) << "HashAggrJit initialized new groups for chunk " << chunk.functionName() << " newGroups=" << newGroups.size(); } diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 25a61b845..73b52e23b 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -78,6 +78,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { false, !context.isRawInput, false, + /*initSetsNull=*/true, hashAggrJitOps()}; } diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index d153f67fa..6493748d0 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -74,6 +74,7 @@ class AverageAggregate false, true, false, + /*initSetsNull=*/true, hashAggrJitOps()}; } @@ -88,6 +89,7 @@ class AverageAggregate false, false, false, + /*initSetsNull=*/true, hashAggrJitOps()}; } @@ -274,6 +276,7 @@ class DecimalAverageAggregate : public DecimalAggregate { false, !context.isRawInput, true, + /*initSetsNull=*/true, hashAggrJitOps()}; } #endif diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index bc5209861..6b3d4a2c4 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -92,6 +92,7 @@ class DecimalSumAggregate : public exec::Aggregate { false, !context.isRawInput, true, + /*initSetsNull=*/true, hashAggrJitOps()}; } diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index b67219551..52cb70d55 100644 --- 
a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -79,6 +79,7 @@ class SumAggregate : public SumAggregateBase { false, !context.isRawInput, false, + /*initSetsNull=*/true, hashAggrJitOps()}; } diff --git a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 5d89a6cab..6c8a45f59 100644 --- a/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/bolt/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -215,6 +215,33 @@ TEST_F(SumAggregationTest, hashAggrJitPartialAvgExtractAccumulators) { assertEqualResults({noJit}, {jit}); } +TEST_F(SumAggregationTest, hashAggrJitAllNullGroup) { + // Repro: one group (c0 == 1) has all-null sum input. Spark sum over an + // all-null group must yield null, not 0. Partial+final two-stage plan. + auto input = makeRowVector( + {makeFlatVector(8, [](auto row) { return row % 2; }), + makeFlatVector( + 8, + [](auto row) { return static_cast(row); }, + // c0 == 1 rows (odd rows) are all null. 
+ [](auto row) { return row % 2 == 1; })}); + + auto plan = PlanBuilder(pool()) + .values({input}) + .partialAggregation({"c0"}, {"spark_sum(c1)"}) + .finalAggregation() + .planNode(); + + auto noJit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "false") + .copyResults(pool()); + auto jit = AssertQueryBuilder(plan) + .config(core::QueryConfig::kHashAggrJitEnabled, "true") + .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "1") + .copyResults(pool()); + assertEqualResults({noJit}, {jit}); +} + TEST_F(SumAggregationTest, hashAggrJitSplitsContiguousSegments) { auto input = makeRowVector( {makeFlatVector(512, [](auto row) { return row % 16; }), diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 78960f645..d1c5fac94 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -58,6 +58,10 @@ struct HashAggrJitDescriptor { bool countStar{false}; bool mergeInput{false}; bool decimal{false}; + // Whether initGroup marks the accumulator as null for each new group. When + // true, GroupingSet must keep Aggregate::numNulls_ in sync (non-JIT extract + // relies on it), mirroring the non-JIT initializeNewGroups path. 
+ bool initSetsNull{false}; const struct HashAggrJitOps* ops{nullptr}; std::string signature() const; @@ -100,6 +104,7 @@ struct HashAggrJitSlot { bool countStar{false}; bool mergeInput{false}; bool decimal{false}; + bool initSetsNull{false}; const HashAggrJitOps* ops{nullptr}; }; From ca78493abe5bbdbedde9f0976ec6003a7675cee0 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 7 Jun 2026 15:33:17 +0800 Subject: [PATCH 21/98] support extractxxx in jit --- bolt/exec/Aggregate.cpp | 4 + .../prestosql/aggregates/CountAggregate.cpp | 8 +- .../prestosql/aggregates/MinMaxAggregates.cpp | 10 +- .../sparksql/aggregates/AverageAggregate.cpp | 128 ++++++++++++++++-- .../sparksql/aggregates/DecimalSumAggregate.h | 19 ++- .../sparksql/aggregates/SumAggregate.cpp | 119 +++++++++++++++- bolt/jit/aggregation/HashAggrJit.cpp | 79 +++++++++-- bolt/jit/aggregation/HashAggrJit.h | 43 ++++++ 8 files changed, 372 insertions(+), 38 deletions(-) diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index 853762139..1aebd48ec 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -347,6 +347,10 @@ jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( .mergeInput = descriptor.mergeInput, .decimal = descriptor.decimal, .initSetsNull = descriptor.initSetsNull, + .precision = descriptor.precision, + .scale = descriptor.scale, + .auxPrecision = descriptor.auxPrecision, + .auxScale = descriptor.auxScale, .ops = descriptor.ops}; } #endif diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 367259e67..2b98db0d2 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -87,6 +87,11 @@ class CountAggregate : public SimpleNumericAggregate { context.isCountStar(), !context.isRawInput, false, + /*initSetsNull=*/false, + /*precision=*/0, + /*scale=*/0, + /*auxPrecision=*/0, + /*auxScale=*/0, hashAggrJitOps()}; } @@ -147,7 +152,8
@@ class CountAggregate : public SimpleNumericAggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { - return false; + // count result is always BIGINT and never null. + return true; } static void compileHashAggrJitExtract( diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 73b52e23b..20c34620c 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -79,6 +79,10 @@ class MinMaxAggregate : public SimpleNumericAggregate { !context.isRawInput, false, /*initSetsNull=*/true, + /*precision=*/0, + /*scale=*/0, + /*auxPrecision=*/0, + /*auxScale=*/0, hashAggrJitOps()}; } @@ -170,8 +174,10 @@ class MinMaxAggregate : public SimpleNumericAggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot& slot, bool) { - // return slot.accumulatorKind != jit::HashAggrJitValueKind::Int128; - return false; + // Flat setters exist for i8/i16/i32/i64/f32/f64 only. Int128 (long decimal) + // and Bool have no flat setter yet, fall back to non-JIT extract. 
+ return slot.accumulatorKind != jit::HashAggrJitValueKind::Int128 && + slot.accumulatorKind != jit::HashAggrJitValueKind::Bool; } static void compileHashAggrJitExtract( diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 6493748d0..78aea72d5 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -31,6 +31,62 @@ #include "bolt/functions/sparksql/aggregates/AverageAggregate.h" #include "bolt/functions/lib/aggregates/AverageAggregateBase.h" #include "bolt/functions/sparksql/DecimalUtil.h" + +#ifdef ENABLE_BOLT_JIT +#include "bolt/jit/aggregation/HashAggrJit.h" +#include "bolt/type/DecimalUtil.h" + +extern "C" { + +// Partial decimal avg extract: write row(sum:decimal, count:bigint). +// Overflow during sum adjustment -> sum child set to null, count kept. +// (Final decimal avg extract stays on the non-JIT path; the rescale logic is +// too coupled to per-aggregate precision metadata.) 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t /*precision*/, + int32_t /*scale*/, + int8_t /*longDecimal*/) { + auto* state = + reinterpret_cast(group + offset); + auto* rowVector = + reinterpret_cast(vector) + ->as(); + auto* sumVector = + rowVector->childAt(0)->asFlatVector(); + auto* countVector = rowVector->childAt(1)->asFlatVector(); + rowVector->setNull(row, false); + countVector->set(row, state->count); + std::optional adjustedSum = + bytedance::bolt::DecimalUtil::adjustSumForOverflow( + state->sum, state->overflow); + if (adjustedSum.has_value()) { + sumVector->set(row, adjustedSum.value()); + } else { + sumVector->setNull(row, true); + } +} + +// Final decimal avg extract is intentionally not implemented in JIT; the +// declaration exists so the JIT module link succeeds, but it is never called +// because canExtract returns false for the final (non-partial) output. 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalDecimalAvg( + char* /*vector*/, + int32_t /*row*/, + char* /*group*/, + int32_t /*offset*/, + int32_t /*precision*/, + int32_t /*scale*/, + int8_t /*longDecimal*/) {} + +} // extern "C" +#endif + using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { namespace { @@ -75,6 +131,10 @@ class AverageAggregate true, false, /*initSetsNull=*/true, + /*precision=*/0, + /*scale=*/0, + /*auxPrecision=*/0, + /*auxScale=*/0, hashAggrJitOps()}; } @@ -90,6 +150,10 @@ class AverageAggregate false, false, /*initSetsNull=*/true, + /*precision=*/0, + /*scale=*/0, + /*auxPrecision=*/0, + /*auxScale=*/0, hashAggrJitOps()}; } @@ -169,16 +233,42 @@ class AverageAggregate } static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot&, + const jit::HashAggrJitSlot& slot, bool) { - return false; + // Only double avg (sum=double@offset, count=int64@offset+8) is supported. + return slot.accumulatorKind == jit::HashAggrJitValueKind::Double; } static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen&, - llvm::Value*, - const jit::HashAggrJitSlot&, - const jit::HashAggrJitExtractTarget&) {} + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + auto& builder = codegen.builder(); + auto* sum = + codegen.loadValue(group, builder.getDoubleTy(), slot.offset); + auto* count = + codegen.loadValue(group, builder.getInt64Ty(), slot.offset + 8); + if (target.partialOutput) { + // Intermediate output is row(sum:double, count:bigint). All-null group + // yields (0, 0) with a non-null top-level row (isNull = 0), matching the + // non-JIT extractAccumulators path. + codegen.emitPartialAvgResult( + target.resultVector, target.row, sum, count, builder.getInt8(0)); + return; + } + // Final output is double avg. count == 0 means all inputs were null -> null. 
+ auto* isNull = builder.CreateZExt( + builder.CreateICmpEQ(count, builder.getInt64(0)), builder.getInt8Ty()); + auto* countAsDouble = builder.CreateSIToFP(count, builder.getDoubleTy()); + auto* avg = builder.CreateFDiv(sum, countAsDouble); + codegen.emitFlatValue( + target.resultVector, + target.row, + jit::HashAggrJitValueKind::Double, + avg, + isNull); + } static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ @@ -268,6 +358,10 @@ class DecimalAverageAggregate : public DecimalAggregate { } const auto& valueType = context.isRawInput ? context.inputType : context.inputType->childAt(0); + const auto [sumPrecision, sumScale] = + getDecimalPrecisionScale(*sumType_.get()); + const auto [resultPrecision, resultScale] = + getDecimalPrecisionScale(*this->resultType().get()); return jit::HashAggrJitDescriptor{ jit::HashAggrJitKind::Avg, valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 @@ -277,6 +371,10 @@ class DecimalAverageAggregate : public DecimalAggregate { !context.isRawInput, true, /*initSetsNull=*/true, + /*precision=*/sumPrecision, + /*scale=*/sumScale, + /*auxPrecision=*/resultPrecision, + /*auxScale=*/resultScale, hashAggrJitOps()}; } #endif @@ -617,15 +715,21 @@ class DecimalAverageAggregate : public DecimalAggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, - bool) { - return false; + bool partialOutput) { + // Only the partial (extractAccumulators) path is JIT-supported for decimal + // avg. Final avg needs the full per-aggregate rescale logic and stays on + // the non-JIT path. 
+ return partialOutput; } static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen&, - llvm::Value*, - const jit::HashAggrJitSlot&, - const jit::HashAggrJitExtractTarget&) {} + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + codegen.emitDecimalAvgExtract( + target.resultVector, target.row, group, slot, target.partialOutput); + } static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 6b3d4a2c4..ffdf16e25 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -84,6 +84,8 @@ class DecimalSumAggregate : public exec::Aggregate { } const auto& valueType = context.isRawInput ? context.inputType : context.inputType->childAt(0); + const auto [resultPrecision, resultScale] = + getDecimalPrecisionScale(*sumType_.get()); return jit::HashAggrJitDescriptor{ jit::HashAggrJitKind::Sum, valueType->isShortDecimal() ? 
jit::HashAggrJitValueKind::Int64 @@ -93,6 +95,10 @@ class DecimalSumAggregate : public exec::Aggregate { !context.isRawInput, true, /*initSetsNull=*/true, + /*precision=*/resultPrecision, + /*scale=*/resultScale, + /*auxPrecision=*/0, + /*auxScale=*/0, hashAggrJitOps()}; } @@ -189,14 +195,17 @@ class DecimalSumAggregate : public exec::Aggregate { static bool canCompileHashAggrJitExtract( const jit::HashAggrJitSlot&, bool) { - return false; + return true; } static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen&, - llvm::Value*, - const jit::HashAggrJitSlot&, - const jit::HashAggrJitExtractTarget&) {} + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + codegen.emitDecimalSumExtract( + target.resultVector, target.row, group, slot, target.partialOutput); + } static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index 52cb70d55..44f65ca8d 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -32,6 +32,100 @@ #include "bolt/functions/lib/aggregates/SumAggregateBase.h" #include "bolt/functions/sparksql/aggregates/DecimalSumAggregate.h" + +#ifdef ENABLE_BOLT_JIT +#include "bolt/type/DecimalUtil.h" + +namespace { +// Mirrors DecimalSumAggregate::computeFinalValue: applies overflow adjustment +// and reports whether the value overflows the result precision range. 
+bytedance::bolt::int128_t jitDecimalSumComputeFinal( + const bytedance::bolt::jit::JitDecimalSumState* state, + int32_t precision, + bool& overflow) { + using bytedance::bolt::DecimalUtil; + bytedance::bolt::int128_t sum = state->sum; + if ((state->overflow == 1 && state->sum < 0) || + (state->overflow == -1 && state->sum > 0)) { + sum = static_cast( + DecimalUtil::kOverflowMultiplier * state->overflow + state->sum); + } else if (state->overflow != 0) { + overflow = true; + return 0; + } + overflow = !DecimalUtil::valueInPrecisionRange(sum, precision); + return sum; +} +} // namespace + +extern "C" { + +// Final decimal sum extract: write FlatVector. Null when the group is +// empty (all inputs null) or the sum overflows the result precision. +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int8_t /*longDecimal*/) { + auto* state = + reinterpret_cast(group + offset); + auto* flat = reinterpret_cast(vector) + ->as>(); + if (state->isEmpty) { + flat->setNull(row, true); + return; + } + bool overflow = false; + auto result = jitDecimalSumComputeFinal(state, precision, overflow); + if (overflow) { + flat->setNull(row, true); + } else { + flat->set(row, result); + } +} + +// Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int8_t /*longDecimal*/) { + auto* state = + reinterpret_cast(group + offset); + auto* rowVector = + reinterpret_cast(vector) + ->as(); + auto* sumVector = rowVector->childAt(0) + ->asFlatVector(); + auto* isEmptyVector = rowVector->childAt(1)->asFlatVector(); + rowVector->setNull(row, false); + if (state->isEmpty) { + sumVector->set(row, 0); + isEmptyVector->set(row, true); + return; + } + bool overflow = false; + auto result = jitDecimalSumComputeFinal(state, precision, overflow); + if (overflow) { + sumVector->setNull(row, true); + isEmptyVector->set(row, false); + } else { + sumVector->set(row, result); + isEmptyVector->set(row, state->isEmpty); + } +} + +} // extern "C" +#endif + using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { @@ -80,6 +174,10 @@ class SumAggregate : public SumAggregateBase { !context.isRawInput, false, /*initSetsNull=*/true, + /*precision=*/0, + /*scale=*/0, + /*auxPrecision=*/0, + /*auxScale=*/0, hashAggrJitOps()}; } @@ -146,16 +244,25 @@ class SumAggregate : public SumAggregateBase { } static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot&, + const jit::HashAggrJitSlot& slot, bool) { - return false; + // spark sum intermediate type == result type (bigint=bigint / double=double). 
+ return slot.accumulatorKind == jit::HashAggrJitValueKind::Int64 || + slot.accumulatorKind == jit::HashAggrJitValueKind::Double; } static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen&, - llvm::Value*, - const jit::HashAggrJitSlot&, - const jit::HashAggrJitExtractTarget&) {} + jit::HashAggrJitCodegen& codegen, + llvm::Value* group, + const jit::HashAggrJitSlot& slot, + const jit::HashAggrJitExtractTarget& target) { + auto* value = + codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + auto* isNull = codegen.builder().CreateZExt( + codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); + codegen.emitFlatValue( + target.resultVector, target.row, slot.accumulatorKind, value, isNull); + } static const jit::HashAggrJitOps* hashAggrJitOps() { static const jit::HashAggrJitOps kOps{ diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index a852b6cbd..6ed7973ac 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -18,6 +18,9 @@ extern "C" { +using bytedance::bolt::jit::JitDecimalAvgState; +using bytedance::bolt::jit::JitDecimalSumState; + namespace { void logHashAggrJitFunctionIR( @@ -46,18 +49,6 @@ void logHashAggrJitFunctionIR( << ir; } -struct JitDecimalSumState { - bytedance::bolt::int128_t sum{0}; - int64_t overflow{0}; - bool isEmpty{true}; -}; - -struct JitDecimalAvgState { - bytedance::bolt::int128_t sum{0}; - int64_t count{0}; - int64_t overflow{0}; -}; - int64_t jitHashAggrAddWithOverflow( bytedance::bolt::int128_t left, bytedance::bolt::int128_t right, @@ -255,6 +246,28 @@ void ensureBuiltinDeclarations(llvm::Module& module) { "jit_HashAggrSetPartialAvgDouble", voidTy, {i8PtrTy, i32Ty, doubleTy, i64Ty, i8Ty}); + // Decimal extract helpers: (vector, row, group, offset, precision, scale, + // longDecimal). 
+ declareFunction( + module, + "jit_HashAggrExtractFinalDecimalSum", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialDecimalSum", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + declareFunction( + module, + "jit_HashAggrExtractFinalDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); } llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { @@ -539,6 +552,48 @@ void HashAggrJitCodegen::emitPartialAvgResult( {vector, row, sum, count, isNull}); } +void HashAggrJitCodegen::emitDecimalSumExtract( + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot, + bool partialOutput) const { + const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalSum" + : "jit_HashAggrExtractFinalDecimalSum"; + auto* longDecimal = builder().getInt8( + slot.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + builder().CreateCall( + module_.getFunction(fn), + {vector, + row, + group, + builder().getInt32(slot.offset), + builder().getInt32(slot.precision), + builder().getInt32(slot.scale), + longDecimal}); +} + +void HashAggrJitCodegen::emitDecimalAvgExtract( + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot, + bool partialOutput) const { + const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalAvg" + : "jit_HashAggrExtractFinalDecimalAvg"; + auto* longDecimal = builder().getInt8( + slot.inputKind == HashAggrJitValueKind::Int128 ? 
1 : 0); + builder().CreateCall( + module_.getFunction(fn), + {vector, + row, + group, + builder().getInt32(slot.offset), + builder().getInt32(slot.precision), + builder().getInt32(slot.scale), + longDecimal}); +} + namespace { bool genAddDenseIR( diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index d1c5fac94..027462e73 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -18,6 +18,22 @@ namespace bytedance::bolt::jit { class HashAggrJitCodegen; + +// JIT-internal accumulator layouts for decimal sum/avg. Shared between the JIT +// codegen runtime helpers and the extract runtime helpers (which live in a +// different translation unit and need DecimalUtil). +struct JitDecimalSumState { + bytedance::bolt::int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; +}; + +struct JitDecimalAvgState { + bytedance::bolt::int128_t sum{0}; + int64_t count{0}; + int64_t overflow{0}; +}; + struct HashAggrJitSlot; struct HashAggrJitExtractTarget; @@ -62,6 +78,14 @@ struct HashAggrJitDescriptor { // true, GroupingSet must keep Aggregate::numNulls_ in sync (non-JIT extract // relies on it), mirroring the non-JIT initializeNewGroups path. bool initSetsNull{false}; + // Result decimal precision/scale, used by decimal extract overflow checks. + // Only meaningful when decimal == true. + int32_t precision{0}; + int32_t scale{0}; + // Secondary decimal precision/scale. For decimal avg extract, precision/scale + // carry the intermediate sum type and aux* carry the result type. 
+ int32_t auxPrecision{0}; + int32_t auxScale{0}; const struct HashAggrJitOps* ops{nullptr}; std::string signature() const; @@ -105,6 +129,10 @@ struct HashAggrJitSlot { bool mergeInput{false}; bool decimal{false}; bool initSetsNull{false}; + int32_t precision{0}; + int32_t scale{0}; + int32_t auxPrecision{0}; + int32_t auxScale{0}; const HashAggrJitOps* ops{nullptr}; }; @@ -176,6 +204,21 @@ class HashAggrJitCodegen { llvm::Value* sum, llvm::Value* count, llvm::Value* isNull) const; + // Decimal extract: calls a runtime helper that reads the JIT decimal + // accumulator from 'group + slot.offset', applies overflow/precision checks + // and writes the result (final flat decimal / partial row) into 'vector'. + void emitDecimalSumExtract( + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot, + bool partialOutput) const; + void emitDecimalAvgExtract( + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot, + bool partialOutput) const; private: llvm::Module& module_; From 741ae9332911372972cc1a0ecdedcd7116b27cc4 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 7 Jun 2026 18:39:43 +0800 Subject: [PATCH 22/98] reset log level for jit --- bolt/exec/GroupingSet.cpp | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 5cbaf1bad..98f15e502 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -873,7 +873,7 @@ const SelectivityVector& GroupingSet::getSelectivityVector( void GroupingSet::maybeCreateHashAggrJitPlan() { hashAggrJitChunks_.clear(); if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { - LOG(INFO) << "HashAggrJit plan disabled: enableHashAggrJit=" + VLOG(1) << "HashAggrJit plan disabled: enableHashAggrJit=" << queryConfig_.enableHashAggrJit() << " isGlobal=" << isGlobal_ << " ignoreNullKeys=" << ignoreNullKeys_; return; @@ -884,7 
+884,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto maxFuseWidth = std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); const auto minChunkWidth = minFuseWidth; - LOG(INFO) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ + VLOG(1) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ << " isPartial=" << isPartial_ << " aggregates=" << aggregates_.size() << " minFuseWidth=" << minFuseWidth @@ -904,7 +904,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto& slot = currentChunkSlots[i]; out << hashAggrJitSlotDebugString(slot, &aggregates_[slot.aggregateIndex]); } - LOG(INFO) << "HashAggrJit discard chunk candidate due to width " + VLOG(1) << "HashAggrJit discard chunk candidate due to width " << currentChunkSlots.size() << " < " << minChunkWidth << ": [" << out.str() << "]"; } @@ -914,11 +914,11 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { jit::HashAggrJitChunk chunk(std::move(currentChunkSlots), isPartial_); if (chunk.codegen()) { hashAggrJitChunks_.push_back(std::move(chunk)); - LOG(INFO) << "HashAggrJit formed chunk: " + VLOG(1) << "HashAggrJit formed chunk: " << hashAggrJitChunkDebugString( hashAggrJitChunks_.back(), aggregates_); } else { - LOG(INFO) << "HashAggrJit chunk codegen failed for chunk " + VLOG(1) << "HashAggrJit chunk codegen failed for chunk " << chunk.functionName(); } currentChunkSlots.clear(); @@ -928,7 +928,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { for (auto i = 0; i < aggregates_.size(); ++i) { auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); if (!slot.has_value()) { - LOG(INFO) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" + VLOG(1) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" << aggregates_[i].name << ") rawInputTypes=[" << [&]() { std::ostringstream out; @@ -953,13 +953,13 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { if (currentChunkSlots.size() >= maxFuseWidth) { flushChunk(); } - LOG(INFO) << "HashAggrJit 
aggregate is JIT-able: " + VLOG(1) << "HashAggrJit aggregate is JIT-able: " << hashAggrJitSlotDebugString(*slot, &aggregates_[i]); currentChunkSlots.push_back(*slot); } flushChunk(); - LOG(INFO) << "HashAggrJit planning finished: totalChunks=" + VLOG(1) << "HashAggrJit planning finished: totalChunks=" << hashAggrJitChunks_.size(); } @@ -971,7 +971,7 @@ void GroupingSet::runHashAggrJitChunks( std::vector& jitExecuted) { if (hashAggrJitChunks_.empty() || hasSpilled() || bypassProbeHT_ || supportRowBasedOutput_ || !activeRows_.isAllSelected()) { - LOG(INFO) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() + VLOG(1) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() << " hasSpilled=" << hasSpilled() << " bypassProbeHT=" << bypassProbeHT_ << " supportRowBasedOutput=" << supportRowBasedOutput_ @@ -982,7 +982,7 @@ void GroupingSet::runHashAggrJitChunks( jitExecuted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { if (!chunk.enabled()) { - LOG(INFO) << "HashAggrJit chunk disabled, skip add: " + VLOG(1) << "HashAggrJit chunk disabled, skip add: " << hashAggrJitChunkDebugString(chunk, aggregates_); continue; } @@ -1038,7 +1038,7 @@ void GroupingSet::runHashAggrJitChunks( } if (!canRunChunk) { - LOG(INFO) << "HashAggrJit chunk cannot run add path, fallback to non-JIT: " + VLOG(1) << "HashAggrJit chunk cannot run add path, fallback to non-JIT: " << hashAggrJitChunkDebugString(chunk, aggregates_) << " reason=" << skipReason; continue; @@ -1060,7 +1060,7 @@ void GroupingSet::runHashAggrJitChunks( newGroups.size()); } } - LOG(INFO) << "HashAggrJit initialized new groups for chunk " + VLOG(1) << "HashAggrJit initialized new groups for chunk " << chunk.functionName() << " newGroups=" << newGroups.size(); } @@ -1069,13 +1069,13 @@ void GroupingSet::runHashAggrJitChunks( activeRows_.end(), hashAggrJitDecodedPtrs_.data(), inputsMayHaveNulls); - LOG(INFO) << "HashAggrJit add executed: chunk=" << chunk.functionName() + VLOG(1) << 
"HashAggrJit add executed: chunk=" << chunk.functionName() << " rows=" << activeRows_.end() << " inputsMayHaveNulls=" << inputsMayHaveNulls << " slots=" << hashAggrJitChunkDebugString(chunk, aggregates_); for (const auto& slot : chunk.slots()) { jitExecuted[slot.aggregateIndex] = 1; - LOG(INFO) << "HashAggrJit slot executed in add path: " + VLOG(1) << "HashAggrJit slot executed in add path: " << hashAggrJitSlotDebugString( slot, &aggregates_[slot.aggregateIndex]); } @@ -1089,7 +1089,7 @@ void GroupingSet::runHashAggrJitExtractChunks( std::vector& jitExtracted) { if (hashAggrJitChunks_.empty() || groups.empty() || hasSpilled() || supportRowBasedOutput_) { - LOG(INFO) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() + VLOG(1) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() << " groups=" << groups.size() << " hasSpilled=" << hasSpilled() << " supportRowBasedOutput=" << supportRowBasedOutput_; return; @@ -1098,7 +1098,7 @@ void GroupingSet::runHashAggrJitExtractChunks( jitExtracted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { if (!chunk.canExtract()) { - LOG(INFO) << "HashAggrJit chunk cannot extract, fallback to non-JIT extract: " + VLOG(1) << "HashAggrJit chunk cannot extract, fallback to non-JIT extract: " << hashAggrJitChunkDebugString(chunk, aggregates_); continue; } @@ -1128,18 +1128,18 @@ void GroupingSet::runHashAggrJitExtractChunks( hashAggrJitResultPtrs_[slotIndex] = reinterpret_cast(aggregateVector.get()); } if (!canRunChunk) { - LOG(INFO) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " + VLOG(1) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " << hashAggrJitChunkDebugString(chunk, aggregates_) << " reason=" << skipReason; continue; } chunk.extract(groups.data(), groups.size(), hashAggrJitResultPtrs_.data()); - LOG(INFO) << "HashAggrJit extract executed: chunk=" << chunk.functionName() + VLOG(1) << "HashAggrJit extract executed: chunk=" << 
chunk.functionName() << " groups=" << groups.size() << " slots=" << hashAggrJitChunkDebugString(chunk, aggregates_); for (const auto& slot : chunk.slots()) { jitExtracted[slot.aggregateIndex] = 1; - LOG(INFO) << "HashAggrJit slot executed in extract path: " + VLOG(1) << "HashAggrJit slot executed in extract path: " << hashAggrJitSlotDebugString( slot, &aggregates_[slot.aggregateIndex]); } From f947804984e3f76e60f8a9ba6642ad4ed6e1788b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 8 Jun 2026 14:45:08 +0800 Subject: [PATCH 23/98] remove helper function from jit for dict/flat/const encodings --- bolt/exec/GroupingSet.cpp | 9 +- bolt/exec/GroupingSet.h | 1 + bolt/jit/aggregation/HashAggrJit.cpp | 105 ++++++++++++-- bolt/jit/aggregation/HashAggrJit.h | 24 ++++ bolt/vector/DecodedVector.h | 4 + doc/hashaggr-jit-benchmark.md | 205 +++++++++++++++++++++++++++ 6 files changed, 337 insertions(+), 11 deletions(-) create mode 100644 doc/hashaggr-jit-benchmark.md diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 98f15e502..7f83c8ab1 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -989,6 +989,7 @@ void GroupingSet::runHashAggrJitChunks( const auto numSlots = chunk.slots().size(); hashAggrJitDecoded_.resize(numSlots); + hashAggrJitDecodedInputs_.resize(numSlots); hashAggrJitInputVectors_.assign(numSlots, nullptr); hashAggrJitDecodedPtrs_.assign(numSlots, nullptr); @@ -1032,9 +1033,15 @@ void GroupingSet::runHashAggrJitChunks( } hashAggrJitInputVectors_[slotIndex] = arg; hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); + hashAggrJitDecodedInputs_[slotIndex] = jit::HashAggrJitDecodedInput{ + hashAggrJitDecoded_[slotIndex].dataAsVoid(), + hashAggrJitDecoded_[slotIndex].indices(), + hashAggrJitDecoded_[slotIndex].nulls(&activeRows_), + &hashAggrJitDecoded_[slotIndex]}; inputsMayHaveNulls = inputsMayHaveNulls || hashAggrJitDecoded_[slotIndex].mayHaveNulls(); - hashAggrJitDecodedPtrs_[slotIndex] = 
reinterpret_cast(&hashAggrJitDecoded_[slotIndex]); + hashAggrJitDecodedPtrs_[slotIndex] = + reinterpret_cast(&hashAggrJitDecodedInputs_[slotIndex]); } if (!canRunChunk) { diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 0781910b5..f8ca2c43a 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -464,6 +464,7 @@ class GroupingSet { #ifdef ENABLE_BOLT_JIT std::vector hashAggrJitChunks_; std::vector hashAggrJitDecoded_; + std::vector hashAggrJitDecodedInputs_; std::vector hashAggrJitInputVectors_; std::vector hashAggrJitDecodedPtrs_; std::vector hashAggrJitNewGroups_; diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 6ed7973ac..96f8a1ae4 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -427,12 +427,80 @@ void setAccumulatorNull( llvm::Value* loadDecodedValue( llvm::IRBuilder<>& builder, - llvm::Module& module, llvm::Value* decoded, llvm::Value* row, const HashAggrJitSlot& slot) { - auto* callee = module.getFunction(decodedValueFunction(slot.inputKind)); - return builder.CreateCall(callee, {decoded, row}); + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* i32Ty = builder.getInt32Ty(); + + auto* valuesPtrPtr = builder.CreatePointerCast(decoded, i8PtrTy->getPointerTo()); + auto* values = builder.CreateLoad(i8PtrTy, valuesPtrPtr, "decoded_values"); + + auto* indicesAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), decoded, static_cast(sizeof(void*))); + auto* indicesPtrPtr = + builder.CreatePointerCast(indicesAddr, i32Ty->getPointerTo()->getPointerTo()); + auto* indices = builder.CreateLoad(i32Ty->getPointerTo(), indicesPtrPtr, "decoded_indices"); + auto* index = builder.CreateLoad( + i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); + + if (slot.inputKind == HashAggrJitValueKind::Bool) { + auto* wordTy = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(index, builder.getInt32(6)); + auto* 
bitIndex = builder.CreateAnd(index, builder.getInt32(63)); + auto* words = builder.CreatePointerCast(values, wordTy->getPointerTo()); + auto* word = builder.CreateLoad( + wordTy, + builder.CreateInBoundsGEP( + wordTy, words, builder.CreateZExt(wordIndex, builder.getInt64Ty()))); + auto* shifted = builder.CreateLShr(word, builder.CreateZExt(bitIndex, wordTy)); + return builder.CreateZExt( + builder.CreateICmpNE( + builder.CreateAnd(shifted, builder.getInt64(1)), builder.getInt64(0)), + builder.getInt8Ty()); + } + + auto* type = llvmType(builder, slot.inputKind); + auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); + auto* valueAddr = builder.CreateInBoundsGEP( + type, typedValues, builder.CreateZExt(index, builder.getInt64Ty())); + auto* load = builder.CreateLoad(type, valueAddr); + load->setAlignment(llvm::Align(1)); + return load; +} + +llvm::Value* loadDecodedNulls(llvm::IRBuilder<>& builder, llvm::Value* decoded) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* nullsAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), decoded, static_cast(2 * sizeof(void*))); + auto* nullsPtrPtr = builder.CreatePointerCast(nullsAddr, i8PtrTy->getPointerTo()); + return builder.CreateLoad(i8PtrTy, nullsPtrPtr, "decoded_nulls"); +} + +llvm::Value* isDecodedNull( + llvm::IRBuilder<>& builder, + llvm::Value* nulls, + llvm::Value* row) { + auto* i64Ty = builder.getInt64Ty(); + auto* nullWords = builder.CreatePointerCast(nulls, i64Ty->getPointerTo()); + auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(row, builder.getInt32(63)); + auto* word = builder.CreateLoad( + i64Ty, + builder.CreateInBoundsGEP( + i64Ty, nullWords, builder.CreateZExt(wordIndex, builder.getInt64Ty()))); + auto* shifted = builder.CreateLShr(word, builder.CreateZExt(bitIndex, i64Ty)); + return builder.CreateICmpNE( + builder.CreateAnd(shifted, builder.getInt64(1)), builder.getInt64(0)); +} + 
+llvm::Value* loadDecodedVector(llvm::IRBuilder<>& builder, llvm::Value* decoded) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* decodedVectorAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), decoded, static_cast(3 * sizeof(void*))); + auto* decodedVectorPtrPtr = + builder.CreatePointerCast(decodedVectorAddr, i8PtrTy->getPointerTo()); + return builder.CreateLoad(i8PtrTy, decodedVectorPtrPtr, "decoded_vector"); } } // namespace @@ -449,8 +517,17 @@ llvm::Value* HashAggrJitCodegen::loadDecodedValue( llvm::Value* decoded, llvm::Value* row, const HashAggrJitSlot& slot) const { - return ::bytedance::bolt::jit::loadDecodedValue( - builder(), module_, decoded, row, slot); + return ::bytedance::bolt::jit::loadDecodedValue(builder(), decoded, row, slot); +} + +llvm::Value* HashAggrJitCodegen::loadDecodedNulls(llvm::Value* decoded) const { + return ::bytedance::bolt::jit::loadDecodedNulls(builder(), decoded); +} + +llvm::Value* HashAggrJitCodegen::isDecodedNull( + llvm::Value* nulls, + llvm::Value* row) const { + return ::bytedance::bolt::jit::isDecodedNull(builder(), nulls, row); } llvm::Value* HashAggrJitCodegen::isAccumulatorNull( @@ -505,18 +582,20 @@ llvm::Value* HashAggrJitCodegen::loadDecodedRowField( const auto name = decodedRowFieldFunction(kind); BOLT_CHECK( !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); + auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); return builder().CreateCall( - module_.getFunction(name), {decoded, row, builder().getInt32(field)}); + module_.getFunction(name), {decodedVector, row, builder().getInt32(field)}); } llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( llvm::Value* decoded, llvm::Value* row, int32_t field) const { + auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); return builder().CreateICmpNE( builder().CreateCall( module_.getFunction("jit_GetDecodedRowFieldIsNull"), - {decoded, row, 
builder().getInt32(field)}), + {decodedVector, row, builder().getInt32(field)}), builder().getInt8(0)); } @@ -695,9 +774,15 @@ bool genAddDenseIR( auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); if (checkInputNulls && !slot.countStar) { - auto* isNull = builder.CreateICmpNE( - builder.CreateCall(module.getFunction("jit_GetDecodedIsNull"), {decoded, row}), - builder.getInt8(0)); + auto* nulls = codegen.loadDecodedNulls(decoded); + auto* nullCheckBlock = + llvm::BasicBlock::Create(context, "slot_null_check", func, end); + auto* hasNulls = builder.CreateICmpNE( + nulls, llvm::ConstantPointerNull::get(i8PtrTy)); + builder.CreateCondBr(hasNulls, nullCheckBlock, updateBlock); + + builder.SetInsertPoint(nullCheckBlock); + auto* isNull = codegen.isDecodedNull(nulls, row); builder.CreateCondBr(isNull, nextBlock, updateBlock); } else { builder.CreateBr(updateBlock); diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 027462e73..578ed0e03 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -37,6 +37,28 @@ struct JitDecimalAvgState { struct HashAggrJitSlot; struct HashAggrJitExtractTarget; +// Runtime input descriptor consumed by JIT add_dense functions. +// GroupingSet prepares one descriptor per aggregate input for each batch by +// decoding the original vector into a flat/constant base plus a single indices +// mapping. This keeps generated IR independent of the batch's original vector +// encoding (flat/dictionary/constant) while allowing the hot loop to load +// values directly instead of calling jit_GetDecodedValue* helpers per row. +struct HashAggrJitDecodedInput { + const void* values{nullptr}; + // Always points to a top-level-row -> base-row mapping. For flat inputs this + // is a consecutive mapping; for constant inputs it maps every row to the + // constant value index. 
+ const int32_t* indices{nullptr}; + // Top-level nulls. If non-null, bit 'row' indicates whether the input row is + // null. This is intentionally row-based rather than base-index-based to keep + // generated IR independent of dictionary/null wrapping details. + const uint64_t* nulls{nullptr}; + // Original DecodedVector pointer. Raw single-value inputs use the descriptor + // fields above directly; intermediate ROW inputs still use row-field helper + // APIs and therefore need the DecodedVector object. + const void* decodedVector{nullptr}; +}; + struct HashAggrJitPlanContext { bool isRawInput{false}; bool isPartialOutput{false}; @@ -163,6 +185,8 @@ class HashAggrJitCodegen { llvm::Value* decoded, llvm::Value* row, const HashAggrJitSlot& slot) const; + llvm::Value* loadDecodedNulls(llvm::Value* decoded) const; + llvm::Value* isDecodedNull(llvm::Value* nulls, llvm::Value* row) const; llvm::Value* isAccumulatorNull( llvm::Value* group, const HashAggrJitSlot& slot) const; diff --git a/bolt/vector/DecodedVector.h b/bolt/vector/DecodedVector.h index fa0f7ecf3..b39145d1a 100644 --- a/bolt/vector/DecodedVector.h +++ b/bolt/vector/DecodedVector.h @@ -151,6 +151,10 @@ class DecodedVector { return reinterpret_cast(data_); } + const void* dataAsVoid() const { + return data_; + } + /// Returns the raw nulls buffer for the base vector combined with nulls found /// in dictionary wrappings. May return nullptr if there are no nulls. Use /// top-level row numbers to access individual null flags, e.g. diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md new file mode 100644 index 000000000..ed4379c58 --- /dev/null +++ b/doc/hashaggr-jit-benchmark.md @@ -0,0 +1,205 @@ +# HashAggr JIT 性能评测报告 + +## 1. 
测试环境与方法 + +- **构建**:`Release` + spark 开关(`spark_compatible=True / enable_testutil=True / + skip_test=False`),对齐 `make release_spark_with_test`;benchmark 单独 + `BOLT_BUILD_BENCHMARKS=ON`,未启用 `enable_perf`(gperftools 源码下载超时,folly + benchmark 不依赖它)。 +- **benchmark**:`bolt/exec/benchmarks/HashAggrJitBenchmark.cpp`,目标 + `bolt_hashaggr_jit_benchmark`。覆盖 sum/avg/min/count(width 4/8/16/32)、 + merge(partial+final)、decimal sum/avg、double min/max、partial extract。 +- **数据规模**:每用例 20 batch × 10000 行。 +- **关键控制**: + - JIT 模块为进程级 LRU 全局缓存,预热后**每个 JIT 函数仅编译一次**(已用 VLOG 验证 + 每个函数名 compile 次数 = 1),编译开销不计入迭代。 + - 两条路径都先 warm-up 再计时;热路径调试日志默认静默(已降级为 `VLOG(1)`)。 + - speedup = nojit / jit,**> 1 表示 JIT 更快**。 + +运行命令: + +```bash +# 低基数(聚合计算密集) +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --hashaggr_jit_benchmark_batches=20 --hashaggr_jit_benchmark_groups=100 + +# 高基数(哈希探测密集) +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --hashaggr_jit_benchmark_batches=20 --hashaggr_jit_benchmark_groups=10000 +``` + +## 2. 低基数结果(groups=100) + +| 聚合 | width4 | width8 | width16 | width32 | +|------|--------|--------|---------|---------| +| **count**(single) | **1.14x** | **1.27x** | **1.25x** | **1.33x** | +| **count**(merge) | **1.15x** | **1.23x** | **1.19x** | **1.34x** | +| sum(single) | 0.46x | 0.38x | 0.38x | 0.34x | +| sum(merge) | 0.47x | 0.40x | 0.37x | 0.37x | +| avg(single) | 0.52x | 0.46x | 0.44x | 0.45x | +| avg(merge) | 0.54x | 0.49x | 0.45x | 0.45x | +| min(single) | 0.49x | 0.42x | 0.38x | 0.39x | +| min(merge) | 0.50x | 0.43x | 0.39x | 0.40x | + +其他(width8):decimal_sum **0.40x** · decimal_avg **0.75x** · double_min 0.57x · +double_max 0.55x · partial_avg_extract 0.82x · partial_sum_extract 0.84x + +> 38 个用例中:JIT 更快 8 个(全部是 count),更慢 30 个。groups=10 对照趋势一致(误差 < 5%)。 + +## 3. 
高基数结果(groups=10000) + +| 聚合 | width4 | width8 | width16 | width32 | +|------|--------|--------|---------|---------| +| **count**(single) | 1.08x | **1.44x** | **1.59x** | **1.60x** | +| **count**(merge) | 0.94x | **1.14x** | **1.22x** | **1.29x** | +| sum(single) | 0.66x | 0.71x | 0.74x | 0.71x | +| sum(merge) | 0.68x | 0.71x | 0.72x | 0.73x | +| avg(single) | 0.80x | 0.79x | 0.84x | 0.80x | +| avg(merge) | 0.52x | 0.49x | 0.51x | 0.52x | +| min(single) | 0.62x | 0.57x | 0.62x | 0.66x | +| min(merge) | 0.68x | 0.60x | 0.64x | 0.64x | + +其他(width8):decimal_sum 0.88x · decimal_avg 0.88x · double_min 0.75x · +double_max 0.63x · partial_avg_extract 0.74x · partial_sum_extract 0.84x + +## 4. 关键发现 + +1. **只有 count 稳定正收益**:低基数 1.14–1.34x,高基数最高 1.60x,且随 fuse 宽度增大而提升。 + count 的 accumulator 最简单,融合循环省下的逐聚合函数调用/分支开销占主导。 +2. **sum/avg/min/max/decimal 在 JIT 下更慢,且基数越低越慢**:sum 从高基数 0.71x 跌到 + 低基数 0.34–0.46x。 +3. **瓶颈在 JIT add 逐行路径,而非哈希探测**:groups=10 与 groups=100 的 JIT 绝对耗时 + 几乎相同(如 width8_sum jit ≈ 5.0ms 两者一致),说明耗时与组数无关、只与行数相关—— + 即**每行 add 成本** JIT 高于非 JIT 的向量化路径。这正是“低基数本应让 JIT 更受益”的 + 预期被反转的根本原因。 +4. **decimal_avg(0.75x) 优于 decimal_sum(0.40x)**:decimal_avg final 走非 JIT(spark + rescale 复杂逻辑),反而拖累较小,侧面印证当前 JIT 计算路径偏慢。 + +## 5. 结论与建议 + +- **现状**:HashAggr JIT 当前仅对 **count** 类(轻 accumulator、宽融合)有明确收益; + sum/avg/min/max/decimal 的 JIT 计算路径尚慢于现有向量化实现,**不建议默认开启**这些 + 聚合的 JIT。 +- **根因**:JIT add 内核的输入读取退化为**逐行外部 C 函数调用**,丧失了内联与向量化 + (详见第 6 章 perf 定位)。 +- **后续可做**: + 1. 把输入读取从 `jit_GetDecodedValue*` 外部调用改为 **JIT 内联**(直接对 flat/identity + 映射的 raw buffer 做 GEP+load),让 LLVM 能向量化取值-累加循环; + 2. 按聚合类型设白名单(先只对 count 默认启用 JIT); + 3. 对 avg-merge 的重路径专门优化。 + +## 6. 
perf 定位:sum/avg add 内核瓶颈 + +环境:`perf`(linux-tools-5.15)+ `perf_event_paranoid=1`,`-F 2999 --call-graph dwarf`, +对 `width16_sum`(fuse=16,groups=100)single-aggregation 的 JIT / 非 JIT 两条路径分别采样。 + +### 6.1 热点符号对比(self time,同一工作负载) + +| 项 | JIT 路径 | 非 JIT 路径 | +|----|----------|-------------| +| 输入取值 | `jit_GetDecodedValueI64`(外部调用,逐行)**45.4%** | `jit_GetDecodedValueI64` 仅 **0.45%** | +| 累加内核 | `[JIT]` 匿名生成码合计 **~25%** | `SumAggregateBase::addRawInput`(内联模板)**52.0%** | +| 哈希探测 | `arrayGroupProbe` 0.9% | `arrayGroupProbe` 2.2% | + +### 6.2 根因分析 + +JIT add 内核的逐行循环(`HashAggrJit.cpp:685` 的 `genAddDenseIR`)对**每行每列**都生成一次 +`CreateCall(jit_GetDecodedValueI64, {decoded, row})`(取值封装见 `loadDecodedValue` +`HashAggrJit.cpp:428`、helper 实现见 `RowContainer.cpp:1724`)。其代价: + +1. **不可内联的跨边界调用**:每行付出 call/ret + 调用约定下的寄存器溢出; +2. `DecodedVector::valueAt(index)` 内部还要判断 identity-mapping、做 indices 间接寻址; +3. **阻断向量化**:取值-累加循环因夹着 opaque 外部调用,LLVM 无法做 SIMD/循环展开。 + +而非 JIT 路径走 `SumAggregateBase::addRawInput`,整批输入在编译期类型已知、`DecodedVector` +raw buffer 被**内联顺序读取**并可向量化,因此 `jit_GetDecodedValueI64` 在该路径几乎不出现(0.45%)。 + +**count 为何不受影响**:count(`countStar` 或仅计数)不读取输入值,在 add 内核里跳过取值与 +null 检查(`HashAggrJit.cpp:697`),故没有 `jit_GetDecodedValue*` 开销,融合循环的省调用收益得以体现。 + +### 6.3 优化方向 + +最高优先级是**消除逐行外部取值调用**:对 flat / identity-mapping 的输入,在 JIT 内核里直接拿到 +`DecodedVector` 的 `data()` 基址,用 `GEP + load`(dictionary 映射则内联 indices 间接寻址)替换 +`jit_GetDecodedValue*` call,使整段取值-累加可被 LLVM 向量化——预期能把 sum/avg/min/max 的 JIT +路径从当前 0.4–0.8x 拉回到 ≥1x。 + +## 7. Direct Decoded Descriptor 优化验证 + +### 7.1 优化内容 + +本轮优化按第 6.3 节方向实现:`GroupingSet` 在每个 batch 为每个聚合输入准备一个轻量 descriptor, +JIT add 内核不再对 raw 单值输入逐行调用 `jit_GetDecodedValue*` / `jit_GetDecodedIsNull`,而是在 IR +内直接读取: + +1. `values`:decoded 后 base vector 的 raw values 基址; +2. `indices`:top-level row -> base row 的映射。flat 为 identity mapping,dictionary / constant 也由 + `DecodedVector` 统一展开为同一套映射; +3. `nulls`:top-level row null bitmap。若为 null,IR 直接跳过 null check; +4. 
`decodedVector`:保留原始 `DecodedVector*`,仅用于 intermediate ROW merge 的 row-field helper。 + +这样同一段 JIT IR 可以同时覆盖 flat / dictionary / constant 输入编码,不需要按 batch encoding 重新 +codegen;热循环中的普通数值读取变为 `index = indices[row]` + `values[index]`。 + +### 7.2 对比方法 + +为了衡量本次优化本身的收益,分别构建并运行了两版同一 benchmark: + +- **baseline/helper-call 版本**:原实现,每行每列调用 `jit_GetDecodedValue*` / `jit_GetDecodedIsNull`; +- **optimized/direct-descriptor 版本**:当前实现,IR 内直接读取 descriptor。 + +运行命令: + +```bash +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=5 \ + --bm_max_secs=3 \ + --bm_regex='(width8_(sum|avg|min|count|double_min|double_max)_(nojit|jit)|width8_decimal_(sum|avg)_(nojit|jit))' +``` + +测试数据规模仍为默认 20 batch × 10000 行,groups=10000,width=8。 + +### 7.3 JIT 自身优化收益 + +下表只比较两版 JIT 路径: + +| case | helper-call JIT | direct-descriptor JIT | JIT 优化收益 | +|------|----------------:|----------------------:|-------------:| +| width8_sum | 6.78ms | 3.94ms | **快 41.9%** | +| width8_avg | 6.90ms | 4.69ms | **快 32.0%** | +| width8_min | 6.71ms | 4.05ms | **快 39.6%** | +| width8_count | 3.00ms | 2.96ms | 快 1.3% | +| width8_decimal_sum | 13.27ms | 9.52ms | **快 28.3%** | +| width8_decimal_avg | 18.08ms | 14.47ms | **快 20.0%** | +| width8_double_min | 6.77ms | 4.81ms | **快 29.0%** | +| width8_double_max | 6.77ms | 4.55ms | **快 32.8%** | + +结论:消除 `jit_GetDecodedValue*` 外部 helper call 后,所有需要读取输入值的聚合都有明显收益, +幅度约 **20%–42%**。`count` 基本不读取 input value,因此收益很小,符合预期。 + +### 7.4 优化后 JIT vs no-JIT + +下表比较当前 direct-descriptor JIT 与非 JIT 路径: + +| case | no-JIT | direct-descriptor JIT | speedup = nojit / jit | +|------|-------:|----------------------:|----------------------:| +| width8_sum | 4.64ms | 3.94ms | **1.18x** | +| width8_avg | 5.29ms | 4.69ms | **1.13x** | +| width8_min | 3.75ms | 4.05ms | 0.93x | +| width8_count | 4.30ms | 2.96ms | **1.45x** | +| width8_decimal_sum | 11.96ms | 9.52ms | **1.26x** | +| width8_decimal_avg | 16.09ms | 14.47ms | **1.11x** | +| width8_double_min | 5.06ms | 
4.81ms | **1.05x** | +| width8_double_max | 4.21ms | 4.55ms | 0.93x | + +优化前,sum / avg / min / decimal / double min/max 的 JIT 路径大多慢于 no-JIT;优化后,sum、avg、 +count、decimal_sum、decimal_avg、double_min 已经变为正收益。`width8_min` 和 `width8_double_max` 仍略慢, +后续需要继续看 min/max accumulator null 判断、compare 分支以及 NaN 处理逻辑。 + +### 7.5 更新后的结论 + +- 第 6 章定位的主要瓶颈(逐行不可内联 helper call)已被验证:去掉 helper call 后,需要读输入值的 + JIT 聚合普遍获得 **20%–42%** 的 JIT 内核收益。 +- HashAggr JIT 不再只有 count 有收益;在当前 width8 / groups=10000 场景下,sum、avg、decimal、 + double_min 也已经超过 no-JIT。 +- 剩余负收益集中在 min/max 类 case,下一步优化重点应转向比较更新逻辑本身,而不是 decoded value 读取。 From e360bb89ac5e63bd40f1de0e98b63d061db5b906 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 8 Jun 2026 15:40:16 +0800 Subject: [PATCH 24/98] add performance analysis after helper function opt --- doc/hashaggr-jit-benchmark.md | 126 ++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index ed4379c58..112014102 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -203,3 +203,129 @@ count、decimal_sum、decimal_avg、double_min 已经变为正收益。`width8_m - HashAggr JIT 不再只有 count 有收益;在当前 width8 / groups=10000 场景下,sum、avg、decimal、 double_min 也已经超过 no-JIT。 - 剩余负收益集中在 min/max 类 case,下一步优化重点应转向比较更新逻辑本身,而不是 decoded value 读取。 + +## 8. 
Direct Descriptor 后的最新 perf 定位 + +### 8.1 为什么仍远低于 multi_sum POC 预期 + +multi_sum POC 的核心收益假设是:把 `sum(c1)..sum(cN)` 合并后,可以显著减少重复的 group/hash lookup, +并让 `NumArgs` 在编译期已知,从而获得 loop unrolling。该假设对 POC 成立,但和当前 Bolt +HashAggregation 的生产路径并不完全等价。 + +当前 `GroupingSet::addInputForActiveRows` 中,hash/group probe 在所有 aggregate 之前统一执行一次: +`prepareForGroupProbe` / `groupProbe` 先生成 `groups = lookup_->hits.data()`,之后才进入聚合函数循环 +或 JIT chunk 执行。因此 **no-JIT 的多个 separate sums 已经共享同一次 hash lookup**,JIT 并不能像 POC +那样再节省 7 次或 15 次 hash lookup。JIT 当前主要节省的是每个 aggregate 独立 `addRawInput` 的函数 +调度、decoded 读取和多次遍历 rows 的开销。 + +另外,benchmark 使用 `AssertQueryBuilder(...).copyResults()` 测的是完整查询路径,不是纯 add kernel: +它还包含 input hash/vector encoding、RowContainer 新 group 初始化、结果 extract、结果 RowVector copy、 +task/benchmark 框架等共同开销。消除 `jit_GetDecodedValue*` 后,add kernel 已明显变快,但完整查询的 +Amdahl 上限被这些公共开销压低。 + +相关代码位置: + +- 单次 group probe:`bolt/exec/GroupingSet.cpp:344` +- probe 后统一进入 JIT chunk / aggregate function add:`bolt/exec/GroupingSet.cpp:375` +- JIT chunk 执行后 no-JIT aggregate 被跳过:`bolt/exec/GroupingSet.cpp:381` + +### 8.2 perf 方法 + +由于当前机器上硬件 counter(cycles/cache-misses/L1/dTLB 等)不可用,本轮使用 software +`cpu-clock` 采样;`/usr/bin/perf` wrapper 找不到匹配 5.4 内核的 perf binary,实际使用 +`/usr/lib/linux-tools-5.15.0-160/perf`。 + +为了减少 benchmark 初始化和 warm-up 对结果的污染,使用较大的 `--bm_min_iters`,让被测 case 的计时 +阶段占主导。代表性命令: + +```bash +/usr/lib/linux-tools-5.15.0-160/perf record -F 2999 \ + -o /tmp/bolt-width16-sum-jit-long.perf.data -- \ + ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=1500 --bm_max_secs=30 --bm_regex='^width16_sum_jit$' + +/usr/lib/linux-tools-5.15.0-160/perf record -F 2999 \ + -o /tmp/bolt-width16-sum-nojit-long.perf.data -- \ + ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=1500 --bm_max_secs=30 --bm_regex='^width16_sum_nojit$' +``` + +### 8.3 最新 sum benchmark 结果 + +在 direct-descriptor 优化后,sum 的完整查询收益如下: + +| case | no-JIT | JIT | speedup = 
nojit / jit | +|------|-------:|----:|----------------------:| +| width4_sum | 2.85ms | 2.73ms | **1.04x** | +| width8_sum | 4.68ms | 4.15ms | **1.13x** | +| width16_sum | 8.92ms | 7.48ms | **1.19x** | +| width32_sum | 17.05ms | 14.84ms | **1.15x** | + +可以看到趋势与 POC 一致:中等宽度有收益;但收益幅度只有约 4%–19%,远低于 POC 中 8/16 列 +约 42%–43% 的提升。根因是当前生产路径的 no-JIT baseline 已共享 hash probe,且完整查询包含较多 +JIT 无法消除的公共开销。 + +### 8.4 width8_sum perf 热点 + +长时间采样下的 self-time 分类如下(`cpu-clock` samples): + +| 分类 | JIT | no-JIT | 说明 | +|------|----:|-------:|------| +| add kernel | **32.72%**(JIT generated add_dense) | **45.48%**(`SumAggregateBase::addRawInput`) | JIT add 已明显少于 no-JIT add | +| hash/vector encoding | 4.15% | 4.67% | 双方共同开销 | +| hash probe | 2.27% | 2.33% | 双方都只 probe 一次,JIT 不再有 POC 中的“省多次 lookup”收益 | +| result/input copy | 3.55% | 2.74% | 完整 `copyResults` 路径成本 | +| RowContainer new/store/init | 1.52% | 2.44% | 新 group / row storage 成本 | +| JIT extract setter | **4.35%** | 0.10% | JIT extract 仍调用 `jit_HashAggrSetFlatI64` helper | +| dynamic_cast/type dispatch | **6.62%** | 0.21% | JIT 路径额外的结果/类型处理开销 | + +top symbols 中 JIT 路径最大热点已经从 `jit_GetDecodedValueI64` 迁移到 `[JIT]` 生成码本身; +`jit_GetDecodedValue*` 不再是主热点,说明第 7 章的 direct descriptor 已生效。 + +### 8.5 width16_sum perf 热点 + +width16 下趋势更明显: + +| 分类 | JIT | no-JIT | 说明 | +|------|----:|-------:|------| +| add kernel | **35.51%**(JIT generated add_dense) | **53.87%**(`SumAggregateBase::addRawInput`) | JIT kernel 节省明显 | +| aggregate init | 0.02% | 4.99% | JIT fused init 基本消除了 per-aggregate init 热点 | +| hash/vector encoding | 1.69% | 2.46% | 共同开销 | +| hash probe | 1.90% | 1.28% | 共同开销;采样误差下同量级 | +| result/input copy | 4.24% | 3.82% | 完整查询共同成本 | +| JIT extract setter | **6.37%** | 0.08% | JIT extract helper 成为新热点之一 | +| dynamic_cast/type dispatch | **7.02%** | 0.17% | JIT 路径额外成本,抵消一部分 add kernel 收益 | + +从绝对耗时估算,JIT add kernel 已从 no-JIT 的约 4.9ms(`8.92ms * 53.87%`)降到约 2.7ms +(`7.48ms * 35.51%`)。也就是说 add kernel 本身接近 **1.8x**,但完整查询最终只有 **1.19x**, 
+因为剩余时间被 probe、encoding、output materialization、copy 和 JIT extract helper/type dispatch 稀释。 + +### 8.6 最新瓶颈排序 + +1. **JIT 已无法再通过“少做 hash lookup”获得 POC 级收益**:Bolt baseline 本身已经 one probe for all + aggregates,这是和 multi_sum POC 最大的结构差异。 +2. **JIT add_dense 生成码仍是最大热点**:direct descriptor 去掉 helper call 后,热点回到真正的 + scalar RMW 聚合内核。当前每个 slot 每行仍要做 `indices[row]`、`values[index]`、accumulator null bit + clear、old accumulator load、add、store;这些操作围绕 `groups[row]` 间接指针,LLVM 很难 SIMD 化。 +3. **JIT extract 仍有 helper/type-dispatch 开销**:`jit_HashAggrSetFlatI64` 和 `__dynamic_cast` 在 JIT + 路径合计约 10%–13%,这是 direct descriptor 后的新显性瓶颈。no-JIT extract 使用 aggregate 自身的 + typed extract,开销低得多。 +4. **完整 query benchmark 的公共成本很高**:hash/vector encoding、RowContainer、result copy、task 框架等 + 不随 add kernel 优化而下降,限制最终端到端 speedup。 + +### 8.7 后续调优建议 + +1. **优化 JIT extract**:像 add path 一样为 extract 也传入 output descriptor(raw values/nulls),在 IR + 里直接写 FlatVector buffer,替换 `jit_HashAggrSetFlatI64` helper,并尽量避免 `dynamic_cast`。 +2. **增加 flat/no-null 快路径**:当前为了同时支持 flat/dictionary/constant,所有输入都走 `indices[row]`。 + 可以保持同一份 IR 兼容多 encoding,但在 loop preheader 根据 descriptor 判断 `indices` 是否 identity, + 分支到 flat 直读 `values[row]` 的 loop;dictionary/constant 再走 mapped loop。这样不需要按 batch + encoding 重新 codegen,但可以让常见 flat case 少一次 indices load。 +3. **消除 no-null 场景下的 per-row accumulator null clear**:sum 当前每行每 slot 都执行 + `clearAccumulatorNull`。对于 input 确认无 null 的 batch,可以考虑 batch-level 或 new-group-level 地清 + accumulator null,避免每行重复写 null bitmap。 +4. **区分纯 add kernel benchmark 与完整 query benchmark**:POC 结论更接近 add kernel 层收益;生产端到端 + 收益需要单独扣除 hash probe/output/copy 等公共成本。后续 benchmark 可以补一个只测 `GroupingSet` add + 的 microbenchmark,避免 `copyResults` 稀释定位。 +5. 
**继续限制 fuse width 的甜点区间**:当前 width16/32 仍有收益,但并未出现 POC 的巨大收益。考虑先保持 + `maxFuseWidth=16` 或最多 32;更宽时需要结合 cache/TLB 数据重新评估。 From 862e32d01777e185e54ae4c279c3c54942acb9de Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 8 Jun 2026 17:33:14 +0800 Subject: [PATCH 25/98] enable jit symbols in perf --- bolt/jit/ThrustJITv2.cpp | 71 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/bolt/jit/ThrustJITv2.cpp b/bolt/jit/ThrustJITv2.cpp index 54d8e4a7f..1adae7e6f 100644 --- a/bolt/jit/ThrustJITv2.cpp +++ b/bolt/jit/ThrustJITv2.cpp @@ -21,16 +21,81 @@ #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/Object/SymbolSize.h" #include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Process.h" #include "llvm/Support/TargetSelect.h" #include #include +#include +#include +#include #include +#include #include namespace bytedance::bolt::jit { +namespace { + +// Returns whether perf-map emission is enabled. The conan LLVM package is built +// without LLVM_USE_PERF, so llvm::JITEventListener::createPerfJITEventListener() +// is a no-op stub that returns nullptr. We therefore emit the perf map file +// ourselves. Controlled by the BOLT_JIT_PERF env var to avoid file IO overhead +// in normal runs. +bool perfMapEnabled() { + static const bool enabled = std::getenv("BOLT_JIT_PERF") != nullptr; + return enabled; +} + +// Appends symbols of a freshly loaded JIT object to /tmp/perf-PID.map so that +// `perf report` can resolve JIT-generated machine code to function names. +// Format per line: "START SIZE NAME" (hex start address, hex size, symbol). +void appendPerfMap( + const llvm::object::ObjectFile& obj, + const llvm::RuntimeDyld::LoadedObjectInfo& loadedInfo) { + // Use the relocated/loaded object so symbol addresses are the runtime ones.
+ auto debugObjOwner = loadedInfo.getObjectForDebug(obj); + const llvm::object::ObjectFile& debugObj = *debugObjOwner.getBinary(); + + static std::mutex perfMapMutex; + std::lock_guard guard(perfMapMutex); + + std::string path = + "/tmp/perf-" + std::to_string(llvm::sys::Process::getProcessId()) + + ".map"; + std::FILE* file = std::fopen(path.c_str(), "a"); + if (file == nullptr) { + return; + } + + for (const auto& [symbol, size] : + llvm::object::computeSymbolSizes(debugObj)) { + auto typeOr = symbol.getType(); + if (!typeOr || *typeOr != llvm::object::SymbolRef::ST_Function) { + continue; + } + auto addrOr = symbol.getAddress(); + auto nameOr = symbol.getName(); + if (!addrOr || !nameOr || size == 0) { + llvm::consumeError(addrOr.takeError()); + llvm::consumeError(nameOr.takeError()); + continue; + } + std::fprintf( + file, + "%llx %llx %.*s\n", + static_cast(*addrOr), + static_cast(size), + static_cast(nameOr->size()), + nameOr->data()); + } + std::fclose(file); +} + +} // namespace + llvm::Expected> ThrustJITv2::Create() { static std::once_flag llvmTargetInitialized; std::call_once(llvmTargetInitialized, []() { @@ -59,7 +124,11 @@ llvm::Expected> ThrustJITv2::Create() { [tracker]( llvm::orc::MaterializationResponsibility& mr, const llvm::object::ObjectFile& obj, - const llvm::RuntimeDyld::LoadedObjectInfo&) { + const llvm::RuntimeDyld::LoadedObjectInfo& + loadedInfo) { + if (perfMapEnabled()) { + appendPerfMap(obj, loadedInfo); + } llvm::orc::ResourceKey resourceKey = 0; if (auto err = mr.withResourceKeyDo( [&](llvm::orc::ResourceKey key) { From a24ff035a00e99c133a66dbb0e073fd413853949 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 8 Jun 2026 21:48:16 +0800 Subject: [PATCH 26/98] =?UTF-8?q?=E4=BC=98=E5=8C=96extractxxx=E4=B8=AD?= =?UTF-8?q?=E7=9A=84helper=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bolt/exec/GroupingSet.cpp | 38 ++++- bolt/exec/GroupingSet.h | 1 + 
bolt/jit/aggregation/HashAggrJit.cpp | 87 +++++++++- bolt/jit/aggregation/HashAggrJit.h | 10 ++ doc/hashaggr-jit-benchmark.md | 229 +++++++++++++++++++++++++++ 5 files changed, 359 insertions(+), 6 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 7f83c8ab1..381d28edb 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -42,6 +42,7 @@ #endif #include "bolt/type/Type.h" #include "bolt/vector/ComplexVector.h" +#include "bolt/vector/FlatVector.h" using bytedance::bolt::common::testutil::TestValue; namespace bytedance::bolt::exec { @@ -68,6 +69,29 @@ std::string hashAggrJitTypeName(const TypePtr& type) { return type == nullptr ? "null" : type->toString(); } +void* hashAggrJitRawOutputValues( + BaseVector* vector, + jit::HashAggrJitValueKind kind) { + switch (kind) { + case jit::HashAggrJitValueKind::Int8: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int16: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int32: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Int64: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Float: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Double: + return vector->asUnchecked>()->mutableRawValues(); + case jit::HashAggrJitValueKind::Bool: + case jit::HashAggrJitValueKind::Int128: + return nullptr; + } + return nullptr; +} + std::string hashAggrJitSlotDebugString( const jit::HashAggrJitSlot& slot, const AggregateInfo* aggregate = nullptr) { @@ -1110,6 +1134,7 @@ void GroupingSet::runHashAggrJitExtractChunks( continue; } const auto numSlots = chunk.slots().size(); + hashAggrJitOutputs_.assign(numSlots, jit::HashAggrJitOutput{}); hashAggrJitResultPtrs_.assign(numSlots, nullptr); bool canRunChunk = true; std::string skipReason; @@ -1132,7 +1157,18 @@ void GroupingSet::runHashAggrJitExtractChunks( 
skipReason = "unexpected result vector encoding"; break; } - hashAggrJitResultPtrs_[slotIndex] = reinterpret_cast(aggregateVector.get()); + // Prepare stable raw output pointers after resizing. The JIT extract + // function still receives char** for ABI compatibility, but each element + // now points to HashAggrJitOutput rather than BaseVector directly. + aggregateVector->resize(groups.size()); + hashAggrJitOutputs_[slotIndex].vector = aggregateVector.get(); + if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { + hashAggrJitOutputs_[slotIndex].values = + hashAggrJitRawOutputValues(aggregateVector.get(), slot.accumulatorKind); + hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); + } + hashAggrJitResultPtrs_[slotIndex] = + reinterpret_cast(&hashAggrJitOutputs_[slotIndex]); } if (!canRunChunk) { VLOG(1) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index f8ca2c43a..c38b143d7 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -468,6 +468,7 @@ class GroupingSet { std::vector hashAggrJitInputVectors_; std::vector hashAggrJitDecodedPtrs_; std::vector hashAggrJitNewGroups_; + std::vector hashAggrJitOutputs_; std::vector hashAggrJitResultPtrs_; #endif diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 96f8a1ae4..52ba563dd 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -339,6 +339,65 @@ bool isFloatKind(HashAggrJitValueKind kind) { kind == HashAggrJitValueKind::Double; } +bool supportsRawFlatOutput(HashAggrJitValueKind kind) { + switch (kind) { + case HashAggrJitValueKind::Int8: + case HashAggrJitValueKind::Int16: + case HashAggrJitValueKind::Int32: + case HashAggrJitValueKind::Int64: + case HashAggrJitValueKind::Float: + case HashAggrJitValueKind::Double: + return true; + case HashAggrJitValueKind::Bool: + case HashAggrJitValueKind::Int128: 
+ return false; + } + return false; +} + +llvm::Value* loadOutputValues(llvm::IRBuilder<>& builder, llvm::Value* output) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* valuesPtrPtr = builder.CreatePointerCast(output, i8PtrTy->getPointerTo()); + return builder.CreateLoad(i8PtrTy, valuesPtrPtr, "output_values"); +} + +llvm::Value* loadOutputNulls(llvm::IRBuilder<>& builder, llvm::Value* output) { + auto* i64Ty = builder.getInt64Ty(); + auto* nullsAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), output, static_cast(sizeof(void*))); + auto* nullsPtrPtr = + builder.CreatePointerCast(nullsAddr, i64Ty->getPointerTo()->getPointerTo()); + return builder.CreateLoad(i64Ty->getPointerTo(), nullsPtrPtr, "output_nulls"); +} + +llvm::Value* loadOutputVector(llvm::IRBuilder<>& builder, llvm::Value* output) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* vectorAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), output, static_cast(2 * sizeof(void*))); + auto* vectorPtrPtr = builder.CreatePointerCast(vectorAddr, i8PtrTy->getPointerTo()); + return builder.CreateLoad(i8PtrTy, vectorPtrPtr, "output_vector"); +} + +void emitOutputNullBit( + llvm::IRBuilder<>& builder, + llvm::Value* nulls, + llvm::Value* row, + llvm::Value* isNull) { + auto* i64Ty = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(row, builder.getInt32(63)); + auto* wordAddr = builder.CreateInBoundsGEP( + i64Ty, nulls, builder.CreateZExt(wordIndex, builder.getInt64Ty())); + auto* word = builder.CreateLoad(i64Ty, wordAddr); + auto* mask = builder.CreateShl( + builder.getInt64(1), builder.CreateZExt(bitIndex, builder.getInt64Ty())); + auto* notNullWord = builder.CreateOr(word, mask); + auto* nullWord = builder.CreateAnd(word, builder.CreateNot(mask)); + auto* isNullBool = builder.CreateICmpNE(isNull, builder.getInt8(0)); + builder.CreateStore( + 
builder.CreateSelect(isNullBool, nullWord, notNullWord), wordAddr); +} + llvm::LoadInst* loadValue( llvm::IRBuilder<>& builder, llvm::Value* row, @@ -600,39 +659,55 @@ llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( } void HashAggrJitCodegen::emitFlatValue( - llvm::Value* vector, + llvm::Value* output, llvm::Value* row, HashAggrJitValueKind kind, llvm::Value* value, llvm::Value* isNull) const { + if (supportsRawFlatOutput(kind)) { + auto* type = llvmType(kind); + auto* values = ::bytedance::bolt::jit::loadOutputValues(builder(), output); + auto* typedValues = builder().CreatePointerCast(values, type->getPointerTo()); + auto* valueAddr = builder().CreateInBoundsGEP( + type, typedValues, builder().CreateZExt(row, builder().getInt64Ty())); + auto* store = builder().CreateStore(value, valueAddr); + store->setAlignment(llvm::Align(1)); + auto* nulls = ::bytedance::bolt::jit::loadOutputNulls(builder(), output); + ::bytedance::bolt::jit::emitOutputNullBit(builder(), nulls, row, isNull); + return; + } + const auto setter = setFlatValueFunction(kind); if (setter.empty()) { return; } + auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(setter), {vector, row, value, isNull}); } void HashAggrJitCodegen::resizeResultVector( - llvm::Value* vector, + llvm::Value* output, llvm::Value* size) const { + auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction("jit_HashAggrResizeVector"), {vector, size}); } void HashAggrJitCodegen::emitPartialAvgResult( - llvm::Value* vector, + llvm::Value* output, llvm::Value* row, llvm::Value* sum, llvm::Value* count, llvm::Value* isNull) const { + auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction("jit_HashAggrSetPartialAvgDouble"), {vector, row, sum, count, isNull}); } void HashAggrJitCodegen::emitDecimalSumExtract( - llvm::Value* vector, + 
llvm::Value* output, llvm::Value* row, llvm::Value* group, const HashAggrJitSlot& slot, @@ -641,6 +716,7 @@ void HashAggrJitCodegen::emitDecimalSumExtract( : "jit_HashAggrExtractFinalDecimalSum"; auto* longDecimal = builder().getInt8( slot.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), {vector, @@ -653,7 +729,7 @@ void HashAggrJitCodegen::emitDecimalSumExtract( } void HashAggrJitCodegen::emitDecimalAvgExtract( - llvm::Value* vector, + llvm::Value* output, llvm::Value* row, llvm::Value* group, const HashAggrJitSlot& slot, @@ -662,6 +738,7 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( : "jit_HashAggrExtractFinalDecimalAvg"; auto* longDecimal = builder().getInt8( slot.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), {vector, diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 578ed0e03..2f6c379b3 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -59,6 +59,16 @@ struct HashAggrJitDecodedInput { const void* decodedVector{nullptr}; }; +// Runtime output descriptor consumed by JIT extract functions. GroupingSet +// prepares one descriptor per aggregate output after resizing the result vector. +// Primitive flat outputs write values/null bits directly from generated IR; +// complex outputs keep using vector helper fallbacks via 'vector'. 
+struct HashAggrJitOutput { + void* values{nullptr}; + uint64_t* nulls{nullptr}; + void* vector{nullptr}; +}; + struct HashAggrJitPlanContext { bool isRawInput{false}; bool isPartialOutput{false}; diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 112014102..df16ca517 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -329,3 +329,232 @@ width16 下趋势更明显: 的 microbenchmark,避免 `copyResults` 稀释定位。 5. **继续限制 fuse width 的甜点区间**:当前 width16/32 仍有收益,但并未出现 POC 的巨大收益。考虑先保持 `maxFuseWidth=16` 或最多 32;更宽时需要结合 cache/TLB 数据重新评估。 + +## 9. JIT extract raw output descriptor 优化验证 + +### 9.1 优化内容 + +本轮继续优化第 8.6 节定位出的 extract 瓶颈:JIT extract 不再对普通 FLAT primitive 输出逐行调用 +`jit_HashAggrSetFlat*` helper,而是由 `GroupingSet` 为每个 aggregate output 准备 +`HashAggrJitOutput` descriptor: + +1. `values`:`FlatVector::mutableRawValues()`; +2. `nulls`:`BaseVector::mutableRawNulls()`; +3. `vector`:原始 `BaseVector*`,保留给 decimal / partial avg ROW 等复杂输出 helper fallback。 + +JIT extract IR 对 `Int8/Int16/Int32/Int64/Float/Double` 直接执行: + +```text +values[row] = value +isNull ? 
clear null bitmap bit : set null bitmap bit +``` + +`Bool`、`Int128/decimal`、partial avg ROW output 暂不做 raw 写,仍通过 descriptor 中的 `vector` 走原 helper。 + +### 9.2 功能与性能验证命令 + +构建: + +```bash +cmake --build --preset conan-release --target bolt_hashaggr_jit_benchmark --parallel 2 +``` + +功能覆盖(sum/avg/min/count/decimal/double min-max): + +```bash +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=3 --bm_max_secs=2 \ + --bm_regex='^width8_(sum|avg|min|count|double_min|double_max|decimal_sum|decimal_avg)_(nojit|jit)$' +``` + +sum 宽度扫描: + +```bash +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=20 --bm_max_secs=5 \ + --bm_regex='^width(4|8|16|32)_sum_(nojit|jit)$' +``` + + +### 9.3 最新 width8 结果 + +| case | no-JIT | JIT | speedup = nojit / jit | +|------|-------:|----:|----------------------:| +| width8_sum | 4.61ms | 3.55ms | **1.30x** | +| width8_avg | 5.41ms | 4.18ms | **1.29x** | +| width8_min | 3.71ms | 3.48ms | **1.07x** | +| width8_count | 4.30ms | 2.41ms | **1.78x** | +| width8_decimal_sum | 12.13ms | 9.58ms | **1.27x** | +| width8_decimal_avg | 16.49ms | 14.77ms | **1.12x** | +| width8_double_min | 4.94ms | 4.23ms | **1.17x** | +| width8_double_max | 4.21ms | 3.81ms | **1.10x** | + +对比第 7.4 节,`min` / `double_max` 已从略慢于 no-JIT 变为正收益;`sum`、`avg`、`count` 也继续提升。 + +### 9.4 最新 sum 宽度扫描 + +| case | no-JIT | JIT | speedup = nojit / jit | +|------|-------:|----:|----------------------:| +| width4_sum | 2.60ms | 2.44ms | **1.07x** | +| width8_sum | 4.65ms | 3.45ms | **1.35x** | +| width16_sum | 9.06ms | 6.06ms | **1.50x** | +| width32_sum | 17.28ms | 12.21ms | **1.42x** | + +相比第 8.3 节(extract 优化前 width16_sum 约 1.19x、width32_sum 约 1.15x),raw output descriptor 后 +宽聚合收益明显扩大,说明之前 extract helper/type-dispatch 确实抵消了大量 add_dense 的融合收益。 + + +### 9.5 详细数据 + +``` +$ ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark +============================================================================ 
+[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s +============================================================================ +width4_sum_nojit 2.57ms 388.70 +width4_sum_jit 2.40ms 416.48 +---------------------------------------------------------------------------- +width4_avg_nojit 3.27ms 306.05 +width4_avg_jit 2.59ms 385.86 +---------------------------------------------------------------------------- +width4_min_nojit 2.40ms 417.30 +width4_min_jit 2.42ms 413.22 +---------------------------------------------------------------------------- +width4_count_nojit 2.42ms 413.03 +width4_count_jit 1.90ms 525.58 +---------------------------------------------------------------------------- +width4_merge_sum_nojit 3.60ms 277.57 +width4_merge_sum_jit 3.29ms 303.98 +---------------------------------------------------------------------------- +width4_merge_avg_nojit 4.58ms 218.27 +width4_merge_avg_jit 7.19ms 138.99 +---------------------------------------------------------------------------- +width4_merge_min_nojit 3.42ms 292.16 +width4_merge_min_jit 3.33ms 300.62 +---------------------------------------------------------------------------- +width4_merge_count_nojit 3.37ms 296.86 +width4_merge_count_jit 2.88ms 346.72 +---------------------------------------------------------------------------- +width8_sum_nojit 4.62ms 216.54 +width8_sum_jit 3.36ms 297.77 +---------------------------------------------------------------------------- +width8_avg_nojit 5.37ms 186.22 +width8_avg_jit 4.14ms 241.70 +---------------------------------------------------------------------------- +width8_min_nojit 3.70ms 270.01 +width8_min_jit 3.50ms 285.84 +---------------------------------------------------------------------------- +width8_count_nojit 4.26ms 235.01 +width8_count_jit 2.35ms 425.31 +---------------------------------------------------------------------------- +width8_merge_sum_nojit 6.18ms 161.79 +width8_merge_sum_jit 4.60ms 217.60 
+---------------------------------------------------------------------------- +width8_merge_avg_nojit 7.70ms 129.85 +width8_merge_avg_jit 12.31ms 81.22 +---------------------------------------------------------------------------- +width8_merge_min_nojit 5.31ms 188.20 +width8_merge_min_jit 4.83ms 206.90 +---------------------------------------------------------------------------- +width8_merge_count_nojit 5.72ms 174.70 +width8_merge_count_jit 3.58ms 279.62 +---------------------------------------------------------------------------- +width16_sum_nojit 9.01ms 110.95 +width16_sum_jit 5.93ms 168.53 +---------------------------------------------------------------------------- +width16_avg_nojit 10.53ms 94.95 +width16_avg_jit 7.38ms 135.53 +---------------------------------------------------------------------------- +width16_min_nojit 7.92ms 126.22 +width16_min_jit 6.24ms 160.20 +---------------------------------------------------------------------------- +width16_count_nojit 7.73ms 129.35 +width16_count_jit 3.50ms 285.49 +---------------------------------------------------------------------------- +width16_merge_sum_nojit 11.44ms 87.44 +width16_merge_sum_jit 7.58ms 131.87 +---------------------------------------------------------------------------- +width16_merge_avg_nojit 15.68ms 63.79 +width16_merge_avg_jit 23.72ms 42.16 +---------------------------------------------------------------------------- +width16_merge_min_nojit 10.21ms 97.95 +width16_merge_min_jit 7.94ms 125.98 +---------------------------------------------------------------------------- +width16_merge_count_nojit 10.10ms 98.97 +width16_merge_count_jit 5.22ms 191.41 +---------------------------------------------------------------------------- +width32_sum_nojit 17.20ms 58.13 +width32_sum_jit 12.08ms 82.76 +---------------------------------------------------------------------------- +width32_avg_nojit 19.42ms 51.48 +width32_avg_jit 15.11ms 66.20 
+---------------------------------------------------------------------------- +width32_min_nojit 15.56ms 64.26 +width32_min_jit 12.53ms 79.78 +---------------------------------------------------------------------------- +width32_count_nojit 15.66ms 63.85 +width32_count_jit 7.12ms 140.37 +---------------------------------------------------------------------------- +width32_merge_sum_nojit 23.30ms 42.91 +width32_merge_sum_jit 16.24ms 61.59 +---------------------------------------------------------------------------- +width32_merge_avg_nojit 30.22ms 33.09 +width32_merge_avg_jit 47.82ms 20.91 +---------------------------------------------------------------------------- +width32_merge_min_nojit 19.79ms 50.52 +width32_merge_min_jit 15.78ms 63.37 +---------------------------------------------------------------------------- +width32_merge_count_nojit 19.32ms 51.75 +width32_merge_count_jit 10.17ms 98.30 +---------------------------------------------------------------------------- +width8_decimal_sum_nojit 12.03ms 83.13 +width8_decimal_sum_jit 9.83ms 101.71 +---------------------------------------------------------------------------- +width8_decimal_avg_nojit 16.29ms 61.38 +width8_decimal_avg_jit 14.77ms 67.70 +---------------------------------------------------------------------------- +width8_double_min_nojit 5.05ms 197.94 +width8_double_min_jit 4.16ms 240.16 +---------------------------------------------------------------------------- +width8_double_max_nojit 4.18ms 239.18 +width8_double_max_jit 3.84ms 260.33 +---------------------------------------------------------------------------- +width8_high_card_partial_avg_extract_nojit 61.78ms 16.19 +width8_high_card_partial_avg_extract_jit 80.29ms 12.46 +---------------------------------------------------------------------------- +width8_high_card_partial_sum_extract_nojit 27.36ms 36.54 +width8_high_card_partial_sum_extract_jit 23.51ms 42.54 +---------------------------------------------------------------------------- +``` + + 
+ +### 9.5 perf 验证 + +代表性命令: + +```bash +/usr/lib/linux-tools-5.15.0-160/perf record -F 2999 \ + -o /tmp/bolt-width16-sum-jit-outputdesc.perf.data -- \ + ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=200 --bm_max_secs=8 --bm_regex='^width16_sum_jit$' + +/usr/lib/linux-tools-5.15.0-160/perf report \ + -i /tmp/bolt-width16-sum-jit-outputdesc.perf.data \ + --stdio --no-children --sort symbol --percent-limit 0 \ + | grep -E 'jit_HashAggrSetFlatI64|dynamic_cast|__dynamic|__do_dyncast|HashAggrSetFlat' +``` + +结果:`jit_HashAggrSetFlatI64` / `HashAggrSetFlat*` 不再出现在 perf report 中;`__dynamic_cast` 降到 +约 **0.27%**,`__do_dyncast` 合计约 **0.15%**。对比第 8.5 节,extract helper 与 dynamic_cast/type +dispatch 从 JIT 路径约 **13%** 的显性热点降为噪声级别。 + +### 9.6 更新后的结论 + +- direct decoded input descriptor 解决了 add_dense 的外部取值 helper;raw output descriptor 继续解决了 + extract 的 per-row setter helper / dynamic_cast。 +- 当前 width8 常见数值聚合已全部为正收益;sum 宽度扫描在 width16 达到约 **1.50x**,更接近最初 multi_sum + POC 的方向性预期。 +- 剩余瓶颈主要回到真正的 JIT 生成码、hash/vector encoding、RowContainer 和 result copy 等公共成本;后续 + 若继续优化,优先考虑 flat/no-null add_dense 快路径、减少 per-row accumulator null clear,以及拆分纯 + `GroupingSet` add microbenchmark 来单独观察 kernel 收益。 From be0e6666903863ed5f9dcdcf24a38c0d6afdea60 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Tue, 9 Jun 2026 08:34:49 +0800 Subject: [PATCH 27/98] update performance report --- doc/hashaggr-jit-benchmark.md | 138 ++++++++++++++++++++++++++++++++-- 1 file changed, 132 insertions(+), 6 deletions(-) diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index df16ca517..598172bfc 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -529,7 +529,131 @@ width8_high_card_partial_sum_extract_jit 23.51ms 42.54 -### 9.5 perf 验证 +### 9.6 当前剩余瓶颈分析 + +从完整 benchmark 结果看,direct decoded input descriptor 和 raw output descriptor 已经解决了此前最明显的 +两类 helper 开销:`jit_GetDecodedValue*` 输入读取 helper,以及 `jit_HashAggrSetFlat*` / 
`dynamic_cast` +输出写 helper。普通 FLAT primitive 聚合现在基本都已经转为正收益,但仍有几类结构性瓶颈。 + +#### 9.6.1 最大负收益:merge avg 的 ROW intermediate 路径 + +目前最明显的回退集中在 `merge_avg_jit`: + +| case | no-JIT | JIT | speedup = nojit / jit | +|------|-------:|----:|----------------------:| +| width4_merge_avg | 4.58ms | 7.19ms | **0.64x** | +| width8_merge_avg | 7.70ms | 12.31ms | **0.63x** | +| width16_merge_avg | 15.68ms | 23.72ms | **0.66x** | +| width32_merge_avg | 30.22ms | 47.82ms | **0.63x** | + +这个比例在不同 width 下非常稳定,说明不是 benchmark 噪声,而是路径本身还没有被优化。根因是 avg merge +的 intermediate input 是 `ROW(sum, count)`,没有完全吃到 raw decoded descriptor 优化:普通数值输入已经能 +通过 `values + indices + nulls` 在 JIT IR 中直接 load,但 ROW field 读取仍然需要类似 +`jit_GetDecodedRowFieldDouble` / `jit_GetDecodedRowFieldI64` / `jit_GetDecodedRowFieldIsNull` 的 helper 或 +DecodedVector row-field 路径。 + +partial avg extract 也印证了这个结论: + +| case | no-JIT | JIT | speedup = nojit / jit | +|------|-------:|----:|----------------------:| +| width8_high_card_partial_avg_extract | 61.78ms | 80.29ms | **0.77x** | +| width8_high_card_partial_sum_extract | 27.36ms | 23.51ms | **1.16x** | + +partial sum 输出是 FLAT,已经受益于 raw output descriptor;partial avg 输出是 ROW,目前仍走 helper fallback, +因此仍然更慢。 + +**建议**:短期可考虑禁用 `merge_avg_jit` 和 `partial_avg_extract_jit`;长期需要为 ROW input/output 增加 +descriptor,把 `sum` / `count` 两个 child vector 的 raw values/nulls 直接传给 JIT。 + +#### 9.6.2 主路径瓶颈:JIT add_dense 仍是 row-based scalar RMW loop + +普通 sum/avg/count 已经有明显正收益: + +| 聚合 | width4 | width8 | width16 | width32 | +|------|-------:|-------:|--------:|--------:| +| sum | 1.07x | 1.38x | 1.52x | 1.42x | +| avg | 1.26x | 1.30x | 1.43x | 1.29x | +| count | 1.27x | 1.81x | 2.21x | 2.20x | + +count 收益显著高于 sum/avg,因为 count 不需要读取 input value,也不需要做加法以外的复杂状态维护。sum/avg +的剩余成本主要回到 JIT 生成码本身: + +```text +group = groups[row] +index = indices[row] +value = values[index] +load accumulator +clear accumulator null bit +add / update count +store accumulator +``` + +这仍然是 row-based scalar read-modify-write 
loop:`groups[row]` 是间接指针访问,`indices[row]` 即使在 flat +input 下也要额外 load,accumulator 存在 RowContainer row storage 中而不是连续 columnar buffer,LLVM 很难 +做 SIMD 化。 + +**建议**:优先做 flat/no-null add_dense 快路径。在 input 是 flat identity mapping 且没有 null 时,直接生成 +`value = values[row]`,跳过 `indices[row]` 和 input null 分支。 + +#### 9.6.3 小 width 收益有限:公共固定成本占比高 + +width4 下收益明显弱于 width8/16/32: + +| case | speedup | +|------|--------:| +| width4_sum | 1.07x | +| width4_min | 0.99x | +| width4_count | 1.27x | + +width4 中可 fusion 的 aggregate 数量少,JIT 能省下的 per-aggregate dispatch / loop traversal 不多,但 descriptor +准备、JIT chunk 调用、result vector resize、output materialization、RowContainer / hash probe / copyResults 等 +完整 query 公共成本仍然存在。 + +**建议**:默认启用策略上应更偏向 width8+ 或 count/sum 这类收益稳定的 case;低 width case 需要结合实际 +query 成本谨慎启用。 + +#### 9.6.4 min/max 收益较小:compare 与 null-init 分支仍偏重 + +min/max 已经转为正收益,但弱于 sum/count: + +| case | speedup | +|------|--------:| +| width8_min | 1.06x | +| width16_min | 1.27x | +| width32_min | 1.24x | +| width8_double_min | 1.21x | +| width8_double_max | 1.09x | + +min/max 每行更新不仅要读取 input value,还要处理 accumulator 是否 null、首次 non-null 初始化、compare 分支; +double min/max 还可能受 NaN / ordering 语义影响。相比 sum 的简单加法,这些分支更难被 LLVM 优化。 + +**建议**:后续可为 no-null + accumulator initialized 场景生成更简单的 compare-only 快路径。 + +#### 9.6.5 decimal 仍受复杂 overflow/precision 逻辑限制 + +decimal 现在已经是正收益,但幅度有限: + +| case | no-JIT | JIT | speedup = nojit / jit | +|------|-------:|----:|----------------------:| +| width8_decimal_sum | 12.03ms | 9.83ms | **1.22x** | +| width8_decimal_avg | 16.29ms | 14.77ms | **1.10x** | + +decimal update/extract 仍包含 int128 accumulator、overflow state、precision/scale 检查、final extract overflow +处理以及 decimal avg rescale 等复杂逻辑,无法像 primitive sum 一样完全变成简单 raw load/store。 + +**建议**:decimal 可以继续专项优化,但优先级低于 ROW avg 路径和 flat/no-null add_dense 快路径。 + +#### 9.6.6 当前瓶颈优先级 + +1. **P0:ROW avg 路径**:`merge_avg_jit` 和 `partial_avg_extract_jit` 是目前唯一大幅负收益路径。短期禁用, + 长期做 ROW input/output descriptor。 +2. 
**P1:flat/no-null add_dense 快路径**:减少 `indices[row]` 间接读取和 null 分支,继续提升 sum/avg/min 主路径。 +3. **P2:减少 per-row accumulator null clear**:对于 no-null input 或 accumulator 已初始化场景,把 null clear 从 + per-row 下沉到更粗粒度。 +4. **P3:min/max compare 快路径**:减少 accumulator null/init 分支。 +5. **P4:decimal 专项优化**:拆解 overflow/precision helper,但收益优先级相对靠后。 + +### 9.7 perf 验证 代表性命令: @@ -549,12 +673,14 @@ width8_high_card_partial_sum_extract_jit 23.51ms 42.54 约 **0.27%**,`__do_dyncast` 合计约 **0.15%**。对比第 8.5 节,extract helper 与 dynamic_cast/type dispatch 从 JIT 路径约 **13%** 的显性热点降为噪声级别。 -### 9.6 更新后的结论 +### 9.8 更新后的结论 - direct decoded input descriptor 解决了 add_dense 的外部取值 helper;raw output descriptor 继续解决了 extract 的 per-row setter helper / dynamic_cast。 -- 当前 width8 常见数值聚合已全部为正收益;sum 宽度扫描在 width16 达到约 **1.50x**,更接近最初 multi_sum - POC 的方向性预期。 -- 剩余瓶颈主要回到真正的 JIT 生成码、hash/vector encoding、RowContainer 和 result copy 等公共成本;后续 - 若继续优化,优先考虑 flat/no-null add_dense 快路径、减少 per-row accumulator null clear,以及拆分纯 +- 当前 width8 常见 FLAT primitive 数值聚合已全部为正收益;sum 宽度扫描在 width16 达到约 **1.52x**,更接近 + 最初 multi_sum POC 的方向性预期。 +- 最大遗留问题是 ROW intermediate/output:`merge_avg_jit` 和 `partial_avg_extract_jit` 仍显著慢于 no-JIT,应 + 优先禁用或实现 ROW descriptor。 +- 主路径剩余瓶颈回到真正的 JIT add_dense 生成码、hash/vector encoding、RowContainer 和 result copy 等公共成本; + 后续若继续优化,优先考虑 flat/no-null add_dense 快路径、减少 per-row accumulator null clear,以及拆分纯 `GroupingSet` add microbenchmark 来单独观察 kernel 收益。 From dc9687dca2322206ab4dbc3ffea40a6b500db77b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Tue, 9 Jun 2026 22:35:10 +0800 Subject: [PATCH 28/98] optimize performance of partial_avg_extract --- bolt/exec/GroupingSet.cpp | 83 +++++++++ bolt/jit/aggregation/HashAggrJit.cpp | 134 +++++++++++++++ bolt/jit/aggregation/HashAggrJit.h | 18 +- doc/hashaggr-jit-benchmark.md | 241 ++++++++++++++++++++++++++- 4 files changed, 471 insertions(+), 5 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 381d28edb..f58c18c14 100644 --- 
a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -92,6 +92,80 @@ void* hashAggrJitRawOutputValues( return nullptr; } +const void* hashAggrJitRawInputValues( + const BaseVector* vector, + jit::HashAggrJitValueKind kind) { + switch (kind) { + case jit::HashAggrJitValueKind::Int8: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int16: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int32: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int64: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Int128: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Float: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Double: + return vector->asUnchecked>()->rawValues(); + case jit::HashAggrJitValueKind::Bool: + return nullptr; + } + return nullptr; +} + +void fillHashAggrJitRowFieldInputs( + jit::HashAggrJitDecodedInput& input, + const DecodedVector& decoded, + const jit::HashAggrJitSlot& slot) { + if (!slot.mergeInput || slot.kind != jit::HashAggrJitKind::Avg) { + return; + } + const auto* base = decoded.base(); + if (base == nullptr || base->encoding() != VectorEncoding::Simple::ROW) { + return; + } + const auto* rowVector = base->asUnchecked(); + if (rowVector->childrenSize() < 2) { + return; + } + const auto& sumVector = rowVector->childAt(0); + const auto& countVector = rowVector->childAt(1); + if (sumVector->encoding() != VectorEncoding::Simple::FLAT || + countVector->encoding() != VectorEncoding::Simple::FLAT) { + return; + } + input.rowField0Values = + hashAggrJitRawInputValues(sumVector.get(), slot.inputKind); + input.rowField0Nulls = sumVector->rawNulls(); + input.rowField1Values = + hashAggrJitRawInputValues(countVector.get(), jit::HashAggrJitValueKind::Int64); + input.rowField1Nulls = countVector->rawNulls(); +} + +void fillHashAggrJitPartialAvgOutput( + 
jit::HashAggrJitOutput& output, + BaseVector* vector) { + auto* rowVector = vector->asUnchecked(); + if (rowVector->childrenSize() < 2) { + return; + } + auto& sumVector = rowVector->childAt(0); + auto& countVector = rowVector->childAt(1); + if (sumVector->encoding() != VectorEncoding::Simple::FLAT || + countVector->encoding() != VectorEncoding::Simple::FLAT) { + return; + } + output.rowField0Values = + sumVector->asUnchecked>()->mutableRawValues(); + output.rowField0Nulls = sumVector->mutableRawNulls(); + output.rowField1Values = + countVector->asUnchecked>()->mutableRawValues(); + output.rowField1Nulls = countVector->mutableRawNulls(); +} + std::string hashAggrJitSlotDebugString( const jit::HashAggrJitSlot& slot, const AggregateInfo* aggregate = nullptr) { @@ -1062,6 +1136,10 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitDecoded_[slotIndex].indices(), hashAggrJitDecoded_[slotIndex].nulls(&activeRows_), &hashAggrJitDecoded_[slotIndex]}; + fillHashAggrJitRowFieldInputs( + hashAggrJitDecodedInputs_[slotIndex], + hashAggrJitDecoded_[slotIndex], + slot); inputsMayHaveNulls = inputsMayHaveNulls || hashAggrJitDecoded_[slotIndex].mayHaveNulls(); hashAggrJitDecodedPtrs_[slotIndex] = @@ -1166,6 +1244,11 @@ void GroupingSet::runHashAggrJitExtractChunks( hashAggrJitOutputs_[slotIndex].values = hashAggrJitRawOutputValues(aggregateVector.get(), slot.accumulatorKind); hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); + } else if (aggregateVector->encoding() == VectorEncoding::Simple::ROW && + slot.kind == jit::HashAggrJitKind::Avg) { + hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); + fillHashAggrJitPartialAvgOutput( + hashAggrJitOutputs_[slotIndex], aggregateVector.get()); } hashAggrJitResultPtrs_[slotIndex] = reinterpret_cast(&hashAggrJitOutputs_[slotIndex]); diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 52ba563dd..a9f60116f 100644 --- 
a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -378,6 +378,68 @@ llvm::Value* loadOutputVector(llvm::IRBuilder<>& builder, llvm::Value* output) { return builder.CreateLoad(i8PtrTy, vectorPtrPtr, "output_vector"); } +llvm::Value* loadPointerField( + llvm::IRBuilder<>& builder, + llvm::Value* descriptor, + uint64_t offset, + llvm::Type* pointerType, + llvm::StringRef name) { + auto* fieldAddr = builder.CreateConstInBoundsGEP1_64( + builder.getInt8Ty(), descriptor, offset); + auto* fieldPtrPtr = builder.CreatePointerCast(fieldAddr, pointerType->getPointerTo()); + return builder.CreateLoad(pointerType, fieldPtrPtr, name); +} + +llvm::Value* loadDecodedIndex( + llvm::IRBuilder<>& builder, + llvm::Value* decoded, + llvm::Value* row) { + auto* i32Ty = builder.getInt32Ty(); + auto* indices = loadPointerField( + builder, + decoded, + sizeof(void*), + i32Ty->getPointerTo(), + "decoded_indices"); + return builder.CreateLoad(i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); +} + +llvm::Value* loadDecodedRowFieldPointer( + llvm::IRBuilder<>& builder, + llvm::Value* decoded, + int32_t field, + bool nulls) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + const auto firstRowFieldOffset = static_cast(4 * sizeof(void*)); + auto* pointerType = nulls ? builder.getInt64Ty()->getPointerTo() : i8PtrTy; + auto offset = firstRowFieldOffset + static_cast(field) * 2 * sizeof(void*) + + (nulls ? sizeof(void*) : 0); + return loadPointerField( + builder, + decoded, + offset, + pointerType, + nulls ? "decoded_row_field_nulls" : "decoded_row_field_values"); +} + +llvm::Value* loadOutputRowFieldPointer( + llvm::IRBuilder<>& builder, + llvm::Value* output, + int32_t field, + bool nulls) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + const auto firstRowFieldOffset = static_cast(3 * sizeof(void*)); + auto* pointerType = nulls ? 
builder.getInt64Ty()->getPointerTo() : i8PtrTy; + auto offset = firstRowFieldOffset + static_cast(field) * 2 * sizeof(void*) + + (nulls ? sizeof(void*) : 0); + return loadPointerField( + builder, + output, + offset, + pointerType, + nulls ? "output_row_field_nulls" : "output_row_field_values"); +} + void emitOutputNullBit( llvm::IRBuilder<>& builder, llvm::Value* nulls, @@ -638,6 +700,18 @@ llvm::Value* HashAggrJitCodegen::loadDecodedRowField( llvm::Value* row, int32_t field, HashAggrJitValueKind kind) const { + if (field == 0 || field == 1) { + auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( + builder(), decoded, field, false); + auto* index = ::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); + auto* type = llvmType(kind); + auto* typedValues = builder().CreatePointerCast(rawValues, type->getPointerTo()); + auto* valueAddr = builder().CreateInBoundsGEP( + type, typedValues, builder().CreateZExt(index, builder().getInt64Ty())); + auto* value = builder().CreateLoad(type, valueAddr); + value->setAlignment(llvm::Align(1)); + return value; + } const auto name = decodedRowFieldFunction(kind); BOLT_CHECK( !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); @@ -650,6 +724,30 @@ llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( llvm::Value* decoded, llvm::Value* row, int32_t field) const { + if (field == 0 || field == 1) { + auto* rawNulls = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( + builder(), decoded, field, true); + auto* hasRawNulls = builder().CreateICmpNE( + rawNulls, llvm::ConstantPointerNull::get(builder().getInt64Ty()->getPointerTo())); + auto* nullCheckBlock = llvm::BasicBlock::Create( + module_.getContext(), "row_field_null_check", builder().GetInsertBlock()->getParent()); + auto* rawDoneBlock = llvm::BasicBlock::Create( + module_.getContext(), "row_field_null_done", builder().GetInsertBlock()->getParent()); + builder().CreateCondBr(hasRawNulls, nullCheckBlock, rawDoneBlock); + auto* 
noNullsEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(nullCheckBlock); + auto* index = ::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); + auto* isNull = ::bytedance::bolt::jit::isDecodedNull(builder(), rawNulls, index); + builder().CreateBr(rawDoneBlock); + auto* nullCheckEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(rawDoneBlock); + auto* fastNull = builder().CreatePHI(builder().getInt1Ty(), 2, "row_field_raw_is_null"); + fastNull->addIncoming(builder().getFalse(), noNullsEnd); + fastNull->addIncoming(isNull, nullCheckEnd); + return fastNull; + } auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); return builder().CreateICmpNE( builder().CreateCall( @@ -700,10 +798,46 @@ void HashAggrJitCodegen::emitPartialAvgResult( llvm::Value* sum, llvm::Value* count, llvm::Value* isNull) const { + auto* sumValues = ::bytedance::bolt::jit::loadOutputRowFieldPointer( + builder(), output, 0, false); + auto* hasRawRowOutput = builder().CreateICmpNE( + sumValues, + llvm::ConstantPointerNull::get( + llvm::PointerType::get(builder().getContext(), 0))); + auto* fastBlock = llvm::BasicBlock::Create( + module_.getContext(), "partial_avg_raw", builder().GetInsertBlock()->getParent()); + auto* helperBlock = llvm::BasicBlock::Create( + module_.getContext(), "partial_avg_helper", builder().GetInsertBlock()->getParent()); + auto* doneBlock = llvm::BasicBlock::Create( + module_.getContext(), "partial_avg_done", builder().GetInsertBlock()->getParent()); + builder().CreateCondBr(hasRawRowOutput, fastBlock, helperBlock); + + builder().SetInsertPoint(fastBlock); + auto* sumTypedValues = + builder().CreatePointerCast(sumValues, builder().getDoubleTy()->getPointerTo()); + auto* countValues = ::bytedance::bolt::jit::loadOutputRowFieldPointer( + builder(), output, 1, false); + auto* countTypedValues = + builder().CreatePointerCast(countValues, builder().getInt64Ty()->getPointerTo()); + auto* row64 = 
builder().CreateZExt(row, builder().getInt64Ty()); + auto* sumAddr = builder().CreateInBoundsGEP(builder().getDoubleTy(), sumTypedValues, row64); + auto* sumStore = builder().CreateStore(sum, sumAddr); + sumStore->setAlignment(llvm::Align(1)); + auto* countAddr = builder().CreateInBoundsGEP(builder().getInt64Ty(), countTypedValues, row64); + auto* countStore = builder().CreateStore(count, countAddr); + countStore->setAlignment(llvm::Align(1)); + auto* nulls = ::bytedance::bolt::jit::loadOutputNulls(builder(), output); + ::bytedance::bolt::jit::emitOutputNullBit(builder(), nulls, row, isNull); + builder().CreateBr(doneBlock); + + builder().SetInsertPoint(helperBlock); auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction("jit_HashAggrSetPartialAvgDouble"), {vector, row, sum, count, isNull}); + builder().CreateBr(doneBlock); + + builder().SetInsertPoint(doneBlock); } void HashAggrJitCodegen::emitDecimalSumExtract( diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 2f6c379b3..f7bf617a3 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -53,10 +53,15 @@ struct HashAggrJitDecodedInput { // null. This is intentionally row-based rather than base-index-based to keep // generated IR independent of dictionary/null wrapping details. const uint64_t* nulls{nullptr}; - // Original DecodedVector pointer. Raw single-value inputs use the descriptor - // fields above directly; intermediate ROW inputs still use row-field helper - // APIs and therefore need the DecodedVector object. + // Original DecodedVector pointer. Kept as fallback for row-field helpers. const void* decodedVector{nullptr}; + // Raw ROW child fields for intermediate avg merge inputs. The top-level + // ROW may still be dictionary/constant wrapped; 'indices' maps rows to the + // flat child row. Only the first two fields are needed by avg: sum, count. 
+ const void* rowField0Values{nullptr}; + const uint64_t* rowField0Nulls{nullptr}; + const void* rowField1Values{nullptr}; + const uint64_t* rowField1Nulls{nullptr}; }; // Runtime output descriptor consumed by JIT extract functions. GroupingSet @@ -67,6 +72,13 @@ struct HashAggrJitOutput { void* values{nullptr}; uint64_t* nulls{nullptr}; void* vector{nullptr}; + // Raw ROW child fields for partial avg output: field 0 = sum(double), + // field 1 = count(int64). Other outputs leave these null and use 'values' + // or helper fallback via 'vector'. + void* rowField0Values{nullptr}; + uint64_t* rowField0Nulls{nullptr}; + void* rowField1Values{nullptr}; + uint64_t* rowField1Nulls{nullptr}; }; struct HashAggrJitPlanContext { diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 598172bfc..dc6a8f7fa 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -679,8 +679,245 @@ dispatch 从 JIT 路径约 **13%** 的显性热点降为噪声级别。 extract 的 per-row setter helper / dynamic_cast。 - 当前 width8 常见 FLAT primitive 数值聚合已全部为正收益;sum 宽度扫描在 width16 达到约 **1.52x**,更接近 最初 multi_sum POC 的方向性预期。 -- 最大遗留问题是 ROW intermediate/output:`merge_avg_jit` 和 `partial_avg_extract_jit` 仍显著慢于 no-JIT,应 - 优先禁用或实现 ROW descriptor。 +- 截至 raw output descriptor 阶段,最大遗留问题是 ROW intermediate/output:`merge_avg_jit` 和 + `partial_avg_extract_jit` 仍显著慢于 no-JIT;第 10 章继续更新了 ROW descriptor 优化后的最新结果。 - 主路径剩余瓶颈回到真正的 JIT add_dense 生成码、hash/vector encoding、RowContainer 和 result copy 等公共成本; 后续若继续优化,优先考虑 flat/no-null add_dense 快路径、减少 per-row accumulator null clear,以及拆分纯 `GroupingSet` add microbenchmark 来单独观察 kernel 收益。 + +## 10. ROW avg input/output descriptor 优化验证 + +### 10.1 优化内容 + +针对第 9.6.1 节的 P0 瓶颈,本轮为 avg 的 `ROW(sum, count)` intermediate input/output 增加了 raw descriptor: + +1. `HashAggrJitDecodedInput` 增加 `rowField0Values/nulls`、`rowField1Values/nulls`,用于 avg merge 直接读取 + partial 输出的 `sum` / `count` child FlatVector; +2. 
`HashAggrJitOutput` 增加同名 row field 指针,用于 partial avg extract 直接写 `sum` / `count` child FlatVector; +3. `loadDecodedRowField` / `isDecodedRowFieldNull` 对 field 0/1 走 `GEP + load` / raw null bitmap,避免逐行 + `DecodedVector` ROW field helper; +4. `emitPartialAvgResult` 在存在 row field raw output 时直接写 child values 和 row null bitmap,保留 helper fallback + 以覆盖非预期编码。 + +这个优化利用了 partial avg 的数据流约束:`addIntermediateResults` 的输入来自 `extractAccumulator`,而 +`extractAccumulator` 输出的 ROW child 均为 FLAT,因此 avg merge 拆 ROW 时只需要支持 child FlatVector 快路径。 + +关键代码位置: + +- ROW input/output descriptor 字段:`bolt/jit/aggregation/HashAggrJit.h:61` +- ROW field raw load/null check:`bolt/jit/aggregation/HashAggrJit.cpp:698` / `bolt/jit/aggregation/HashAggrJit.cpp:723` +- partial avg ROW raw output 写入:`bolt/jit/aggregation/HashAggrJit.cpp:795` +- avg merge ROW child raw input 填充:`bolt/exec/GroupingSet.cpp:120` +- partial avg ROW child raw output 填充:`bolt/exec/GroupingSet.cpp:148` +- JIT extract output descriptor 准备:`bolt/exec/GroupingSet.cpp:1194` + +### 10.2 验证命令 + +构建: + +```bash +cmake --build --preset conan-release --target bolt_hashaggr_jit_benchmark --parallel 1 +``` + +完整 benchmark: + +```bash +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_min_iters=5 +``` + +P0 专项复测: + +```bash +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_regex='width(4|8|16|32)_merge_avg' --bm_min_iters=20 + +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_regex='width8_high_card_partial_avg_extract' --bm_min_iters=50 --bm_max_secs=10 +``` + +### 10.3 merge_avg 修复结果 + +ROW input descriptor 后,`merge_avg_jit` 从此前稳定 0.63–0.66x 的最大负收益路径,变为稳定正收益: + +| case | 修复前 speedup | 修复后 no-JIT | 修复后 JIT | 修复后 speedup | +|------|---------------:|--------------:|-----------:|---------------:| +| width4_merge_avg | 0.64x | 4.63ms | 4.02ms | **1.15x** | +| width8_merge_avg | 0.63x | 7.86ms | 5.88ms | **1.34x** | +| width16_merge_avg | 0.66x | 14.83ms | 10.56ms | 
**1.40x** | +| width32_merge_avg | 0.63x | 31.13ms | 24.21ms | **1.29x** | + +完整 benchmark 的同类结果也保持正收益:width4/8/16/32 分别约 **1.23x / 1.30x / 1.41x / 1.37x**。 + +### 10.4 partial_avg_extract 结果 + +partial avg extract 的 ROW output helper 已被 raw child 写入替换,较第 9.5 节中 80.29ms 的 JIT 路径有明显改善; +但在当前完整查询 benchmark 中,端到端仍有波动且长跑仍略慢于 no-JIT: + +| case | 第 9.5 节 JIT | 修复后 no-JIT | 修复后 JIT | 修复后 speedup | +|------|--------------:|--------------:|-----------:|---------------:| +| width8_high_card_partial_avg_extract | 80.29ms | 60.17ms | 68.19ms | 0.88x | + +对照同一轮 partial sum extract: + +| case | no-JIT | JIT | speedup | +|------|-------:|----:|--------:| +| width8_high_card_partial_sum_extract | 27.50ms | 22.50ms | **1.22x** | + +因此,本轮对 partial avg extract 的结论是:ROW output helper 瓶颈已被削弱,但该 case 的端到端性能还没有稳定转正。 +剩余成本大概率不再只是 ROW child 写出,而是 high-cardinality 场景下每行新 group 初始化、avg accumulator +`sum+count` 更新、RowVector 输出物化以及完整 `copyResults` 公共成本共同导致。 + +### 10.5 完整 benchmark 快照 + +完整 benchmark 命令: + +```bash +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_min_iters=5 +``` + +本轮完整结果的 speedup 汇总如下(speedup = no-JIT / JIT,**> 1 表示 JIT 更快**): + +| 聚合 | width4 | width8 | width16 | width32 | +|------|-------:|-------:|--------:|--------:| +| sum(single) | **1.04x** | **1.34x** | **1.48x** | **1.47x** | +| avg(single) | **1.22x** | **1.27x** | **1.45x** | **1.30x** | +| min(single) | **1.01x** | **1.07x** | **1.27x** | **1.27x** | +| count(single) | **1.28x** | **1.77x** | **2.20x** | **2.28x** | +| sum(merge) | **1.08x** | **1.34x** | **1.49x** | **1.43x** | +| avg(merge) | **1.23x** | **1.30x** | **1.41x** | **1.37x** | +| min(merge) | **1.03x** | **1.11x** | **1.28x** | **1.33x** | +| count(merge) | **1.17x** | **1.57x** | **1.86x** | **1.97x** | + +其他 width8 用例: + +| case | no-JIT | JIT | speedup | +|------|-------:|----:|--------:| +| width8_decimal_sum | 12.05ms | 9.60ms | **1.26x** | +| width8_decimal_avg | 16.15ms | 14.60ms | **1.11x** | +| width8_double_min | 
4.95ms | 4.15ms | **1.19x** | +| width8_double_max | 4.13ms | 3.86ms | **1.07x** | +| width8_high_card_partial_avg_extract | 58.15ms | 63.36ms | 0.92x | +| width8_high_card_partial_sum_extract | 24.48ms | 21.34ms | **1.15x** | + +结论:ROW avg input/output descriptor 之后,完整 benchmark 中除 `partial_avg_extract` 外,当前覆盖的主要 +single / merge primitive 聚合均已转为正收益;`count` 和宽 `sum/avg/merge_avg` 收益最稳定。 + +### 10.6 更新后的瓶颈优先级 + +1. **P0 已基本解决:merge_avg ROW input**。`merge_avg_jit` 已从 0.63–0.66x 拉升到 1.15–1.40x,是本轮最主要收益。 +2. **P1:partial_avg_extract 仍需继续拆解**。ROW output raw descriptor 已降低 JIT 绝对耗时,但端到端仍约 0.88x; + 下一步需要用 perf 区分 add/update、新 group 初始化、ROW output materialization 和 `copyResults` 的占比。 +3. **P2:flat/no-null add_dense 快路径**。普通 sum/avg/min 主路径仍有 `indices[row]` 和 per-row null 处理成本。 +4. **P3:减少 per-row accumulator null clear**。对于 no-null input 或 accumulator 已初始化场景,把 null clear 从 per-row + 下沉到更粗粒度。 +5. **P4:min/max compare 和 decimal 专项优化**。收益优先级低于 partial avg extract 和 add_dense 主路径。 + +### 10.7 本轮结论 + +- avg merge 的 ROW intermediate input 已吃到 raw descriptor 优化,最大负收益 case 已转正。 +- partial avg extract 的 ROW output helper 已优化,但 benchmark 仍显示端到端略慢,需要继续 perf 定位剩余成本。 +- HashAggr JIT 当前更适合 sum/count/avg merge 这类宽融合场景;partial avg extract 暂不应作为默认开启 JIT 的依据。 + +## 11. 
P1:partial_avg_extract 火焰图定位 + +### 11.1 perf 采集与火焰图生成方法 + +对 `width8_high_card_partial_avg_extract` 的 JIT / no-JIT 两条路径分别采样,并生成火焰图。 + +采样(`-F 999 --call-graph dwarf`): + +```bash +# JIT 路径 +/usr/lib/linux-tools-5.15.0-160/perf record -F 999 --call-graph dwarf \ + -o /tmp/bolt-partial-avg-extract-jit.perf.data -- \ + ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=200 --bm_max_secs=20 \ + --bm_regex='^width8_high_card_partial_avg_extract_jit$' + +# no-JIT 路径 +/usr/lib/linux-tools-5.15.0-160/perf record -F 999 --call-graph dwarf \ + -o /tmp/bolt-partial-avg-extract-nojit.perf.data -- \ + ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ + --bm_min_iters=200 --bm_max_secs=20 \ + --bm_regex='^width8_high_card_partial_avg_extract_nojit$' +``` + +折叠栈并用 FlameGraph 生成 SVG: + +```bash +FG=/data00/home/liyang.127/FlameGraph + +/usr/lib/linux-tools-5.15.0-160/perf script -i /tmp/bolt-partial-avg-extract-jit.perf.data \ + | $FG/stackcollapse-perf.pl > /tmp/partial-avg-extract-jit.folded +$FG/flamegraph.pl --title "width8_high_card_partial_avg_extract JIT" \ + /tmp/partial-avg-extract-jit.folded \ + > doc/hashaggr-jit-partial-avg-extract-jit-flamegraph.svg + +/usr/lib/linux-tools-5.15.0-160/perf script -i /tmp/bolt-partial-avg-extract-nojit.perf.data \ + | $FG/stackcollapse-perf.pl > /tmp/partial-avg-extract-nojit.folded +$FG/flamegraph.pl --title "width8_high_card_partial_avg_extract no-JIT" \ + /tmp/partial-avg-extract-nojit.folded \ + > doc/hashaggr-jit-partial-avg-extract-nojit-flamegraph.svg +``` + +> 选用 `--call-graph dwarf` 而非 fp/lbr:Release 二进制开启 `-fomit-frame-pointer`,帧指针回溯会断栈; +> 该机器硬件 PMU/LBR 也不可用(只能用 software `cpu-clock`),dwarf 基于 `.eh_frame`/CFI 回溯, +> 在优化过且含 JIT 匿名段的二进制上能还原完整调用栈,适合做火焰图,代价是数据量大、采样频率需调低到 999。 + +产物火焰图: + +- `doc/hashaggr-jit-partial-avg-extract-jit-flamegraph.svg` +- `doc/hashaggr-jit-partial-avg-extract-nojit-flamegraph.svg` + +### 11.2 采样结构与噪声剥离 + +本次火焰图有两个需要先剥离的结构性噪声,否则会误判热点: + +1. 
**JIT 后台编译线程**:JIT 编译发生在 `CPUThreadPool0` 上的 `llvm::orc::*` / `PassManager` / + `SelectionDAG` 调用链,在原始火焰图里占比很大,但它属于一次性 plan 编译开销(LRU 缓存命中后不再编译), + 不计入热路径。剥离方式是过滤掉 `llvm::orc` / `*PassManager` / `SelectionDAG` / `MachineFunction` 等编译栈。 +2. **benchmark 主线程**:`bolt_hashaggr_j` 主线程几乎全是 plan 解析(`Parser::parse`、`parseTypeSignature`) + 和动态链接 setup(`elf_dynamic_do_Rela`、`do_lookup_x`),是 query 构建噪声,真正的算子执行在 + `CPUThreadPool0` 执行线程上。 + +剥离后,对执行线程上的真实算子热点做 leaf 归类对比(self time,已排除编译栈)。 + +### 11.3 执行线程热点对比 + +| leaf 热点 | JIT | no-JIT | 说明 | +|----------|----:|-------:|------| +| `[perf-*.map]`(JIT 生成码) | 9.7% | 9.7%(no-JIT 是其它匿名段) | JIT add/extract 生成码 | +| `clear_page_erms` | 8.9% | 9.7% | 内核清零新申请页 | +| `arrayGroupProbe` | 4.8% | 1.6% | hash 探测 | +| `AverageAggregateBase::addRawInput` | 3.2% | 5.6% | avg accumulator 更新 | +| `SumAggregateBase::addRawInput` | 3.2% | 2.4% | 子聚合更新 | +| `MinAggregate::addRawInput` | 2.4% | 4.0% | 子聚合更新 | +| `__memset_avx512` / `get_page_from_freelist` / `_int_malloc` | 合计 ~6% | 合计 ~5% | 新 group 内存分配 | +| `RowContainer::initializeRow` / `HashStringAllocator::clear` | ~3% | ~3% | 新 group 初始化 | +| `RowContainer::extractColumn` | 1.6% | 1.6% | 结果列抽取 | +| `VectorHasher::makeValueIdsFlatNoNulls` | 1.6% | 1.6% | key 编码 | + +关键观察: + +1. **两条路径的执行线程热点几乎重合**:top 热点都是 `clear_page_erms` + 内存分配 + `arrayGroupProbe` + + 各 accumulator 的 `addRawInput`,extract 相关符号(`extractColumn` / ROW child 写出)self time 都不到 2%。 +2. **extract 已经不是这个 case 的瓶颈**:第 10 章 ROW output raw descriptor 已把 extract helper 削掉, + 火焰图里 extract 已沉到噪声级别。partial avg extract 端到端略慢,**不是 extract kernel 导致的**。 +3. **真正的成本是 high-cardinality 的新 group 物化**:该 case groups=batches×batch_size(每行一个新组), + 每个新 group 都要 `clear_page` + `malloc` + `RowContainer::initializeRow`,这部分是 JIT/no-JIT 共有的固定成本, + 且占执行线程相当大比例。JIT 在这部分没有任何优化空间。 +4. 
**JIT 反而在 `arrayGroupProbe` 上采样更高(4.8% vs 1.6%)**:在「每行新组」的极端高基数下,JIT chunk 的 + group probe 调用方式相对 no-JIT 没有优势,叠加 add kernel 节省有限,导致端到端被新组物化稀释后呈现约 0.9x。 + +### 11.4 P1 结论 + +- partial_avg_extract 的 ROW output 瓶颈(第 9/10 章定位的 helper / dynamic_cast)已被 raw descriptor 解决, + 火焰图确认 extract self time 已 <2%。 +- 该 case 当前端到端约 0.9x 的剩余差距**不在 JIT 可优化范围内**:主导成本是 high-cardinality「每行新 group」 + 带来的 `clear_page` / 内存分配 / `RowContainer::initializeRow` / `arrayGroupProbe`,JIT 与 no-JIT 共享这部分开销, + JIT 能优化的 add/extract kernel 占比已被压得很低。 +- 因此 P1 的处理结论是:**partial_avg_extract 不再作为独立优化项继续深挖**。它代表的是「聚合计算占比极低、 + 新组物化占比极高」的负向场景,应通过**白名单/启发式**避免对这类 high-cardinality partial 聚合启用 JIT, + 而不是继续优化 extract 本身。 +- 真正还能换来 add kernel 收益的是 P2(flat/no-null add_dense 快路径)和 P3(下沉 per-row accumulator null clear), + 它们作用于计算占比高的 case,优先级高于继续打磨 partial_avg_extract。 From 9b766b12e2613840e87951c282bacb4d1763f5a6 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 10 Jun 2026 20:41:55 +0800 Subject: [PATCH 29/98] update doc --- doc/hashaggr-jit-benchmark.md | 174 ++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index dc6a8f7fa..72a7bfaff 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -921,3 +921,177 @@ $FG/flamegraph.pl --title "width8_high_card_partial_avg_extract no-JIT" \ 而不是继续优化 extract 本身。 - 真正还能换来 add kernel 收益的是 P2(flat/no-null add_dense 快路径)和 P3(下沉 per-row accumulator null clear), 它们作用于计算占比高的 case,优先级高于继续打磨 partial_avg_extract。 + +附上此次优化后的benchmark report + +``` +$ ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark +============================================================================ +[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s +============================================================================ +width4_sum_nojit 2.58ms 387.72 +width4_sum_jit 2.36ms 422.95 
+---------------------------------------------------------------------------- +width4_avg_nojit 3.15ms 317.76 +width4_avg_jit 2.69ms 372.41 +---------------------------------------------------------------------------- +width4_min_nojit 2.46ms 407.31 +width4_min_jit 2.47ms 405.55 +---------------------------------------------------------------------------- +width4_count_nojit 2.40ms 416.38 +width4_count_jit 1.91ms 524.51 +---------------------------------------------------------------------------- +width4_merge_sum_nojit 3.70ms 270.06 +width4_merge_sum_jit 3.31ms 302.23 +---------------------------------------------------------------------------- +width4_merge_avg_nojit 4.53ms 220.94 +width4_merge_avg_jit 3.67ms 272.67 +---------------------------------------------------------------------------- +width4_merge_min_nojit 3.54ms 282.70 +width4_merge_min_jit 3.37ms 296.70 +---------------------------------------------------------------------------- +width4_merge_count_nojit 3.43ms 291.51 +width4_merge_count_jit 2.87ms 348.60 +---------------------------------------------------------------------------- +width8_sum_nojit 4.60ms 217.19 +width8_sum_jit 3.41ms 293.12 +---------------------------------------------------------------------------- +width8_avg_nojit 5.27ms 189.92 +width8_avg_jit 4.10ms 244.09 +---------------------------------------------------------------------------- +width8_min_nojit 3.76ms 266.31 +width8_min_jit 3.53ms 283.64 +---------------------------------------------------------------------------- +width8_count_nojit 4.31ms 231.77 +width8_count_jit 2.37ms 422.36 +---------------------------------------------------------------------------- +width8_merge_sum_nojit 6.17ms 162.07 +width8_merge_sum_jit 4.78ms 209.35 +---------------------------------------------------------------------------- +width8_merge_avg_nojit 7.26ms 137.75 +width8_merge_avg_jit 5.67ms 176.25 +---------------------------------------------------------------------------- 
+width8_merge_min_nojit 5.26ms 189.99 +width8_merge_min_jit 4.82ms 207.44 +---------------------------------------------------------------------------- +width8_merge_count_nojit 5.73ms 174.57 +width8_merge_count_jit 3.55ms 281.72 +---------------------------------------------------------------------------- +width16_sum_nojit 8.90ms 112.38 +width16_sum_jit 5.97ms 167.52 +---------------------------------------------------------------------------- +width16_avg_nojit 10.59ms 94.47 +width16_avg_jit 7.29ms 137.08 +---------------------------------------------------------------------------- +width16_min_nojit 7.66ms 130.62 +width16_min_jit 6.34ms 157.71 +---------------------------------------------------------------------------- +width16_count_nojit 7.70ms 129.92 +width16_count_jit 3.51ms 284.55 +---------------------------------------------------------------------------- +width16_merge_sum_nojit 11.17ms 89.55 +width16_merge_sum_jit 7.29ms 137.12 +---------------------------------------------------------------------------- +width16_merge_avg_nojit 14.74ms 67.86 +width16_merge_avg_jit 10.45ms 95.71 +---------------------------------------------------------------------------- +width16_merge_min_nojit 10.53ms 94.92 +width16_merge_min_jit 8.39ms 119.26 +---------------------------------------------------------------------------- +width16_merge_count_nojit 10.02ms 99.78 +width16_merge_count_jit 5.30ms 188.84 +---------------------------------------------------------------------------- +width32_sum_nojit 17.49ms 57.17 +width32_sum_jit 12.44ms 80.38 +---------------------------------------------------------------------------- +width32_avg_nojit 20.01ms 49.97 +width32_avg_jit 15.68ms 63.77 +---------------------------------------------------------------------------- +width32_min_nojit 15.48ms 64.62 +width32_min_jit 12.96ms 77.16 +---------------------------------------------------------------------------- +width32_count_nojit 16.00ms 62.52 +width32_count_jit 7.39ms 135.35 
+---------------------------------------------------------------------------- +width32_merge_sum_nojit 22.50ms 44.45 +width32_merge_sum_jit 17.12ms 58.41 +---------------------------------------------------------------------------- +width32_merge_avg_nojit 27.85ms 35.90 +width32_merge_avg_jit 21.32ms 46.90 +---------------------------------------------------------------------------- +width32_merge_min_nojit 20.59ms 48.56 +width32_merge_min_jit 17.22ms 58.09 +---------------------------------------------------------------------------- +width32_merge_count_nojit 19.84ms 50.39 +width32_merge_count_jit 11.76ms 85.03 +---------------------------------------------------------------------------- +width8_decimal_sum_nojit 11.94ms 83.76 +width8_decimal_sum_jit 9.86ms 101.41 +---------------------------------------------------------------------------- +width8_decimal_avg_nojit 16.31ms 61.30 +width8_decimal_avg_jit 14.75ms 67.81 +---------------------------------------------------------------------------- +width8_double_min_nojit 4.97ms 201.05 +width8_double_min_jit 4.17ms 239.76 +---------------------------------------------------------------------------- +width8_double_max_nojit 4.29ms 233.10 +width8_double_max_jit 3.95ms 253.45 +---------------------------------------------------------------------------- +width8_high_card_partial_avg_extract_nojit 62.12ms 16.10 +width8_high_card_partial_avg_extract_jit 68.54ms 14.59 +---------------------------------------------------------------------------- +width8_high_card_partial_sum_extract_nojit 25.27ms 39.58 +width8_high_card_partial_sum_extract_jit 23.10ms 43.29 +---------------------------------------------------------------------------- +``` + +## 12. 
P2:flat/identity add_dense 快路径验证(已验证无收益,回退) + +### 12.1 优化假设 + +第 9.6.2 / 第 11 章曾把 add_dense 主路径的 `indices[row]` 间接寻址列为 P2 优化点: +flat(identity mapping)输入时 `indices[row] == row`,`loadDecodedValue` +(`bolt/jit/aggregation/HashAggrJit.cpp`)每行先 `index = indices[row]` 再 +`values[index]`,这一级 load 在 flat 下被认为是可省的冗余,预期省掉后能让取值-累加 +循环更利于向量化。 + +### 12.2 实现与验证 + +按方案 A 实现: + +1. `HashAggrJitDecodedInput` 增 `identityMapping` 标记字段; +2. `GroupingSet` 在准备 descriptor 时用 `DecodedVector::isIdentityMapping()` 填标记; +3. `loadDecodedValue` 在 IR 里据标记选择直接用 `row` 还是 `indices[row]`。 + +验证: + +- **功能**:编译通过;dump add_dense IR 确认 identity 分支正确生成、descriptor + trailing bool 的 offset 读取正确(`align 1`)。spark aggregate JIT 单测 + (`bolt_functions_spark_aggregates_test`,`--gtest_filter='*hashAggrJit*'`) + 3 passed / 2 failed,其中 2 个失败(`hashAggrJitMergeAndExtract`、 + `hashAggrJitAllNullGroup`)经 `git stash` 对比确认是**基线既有 bug,与 P2 无关**, + P2 未引入新回归。 +- **性能**:分别构建 baseline / P2 两个 benchmark binary,交替多轮对比 + `width8/16/32` 的 sum/avg/min jit 耗时。 + +### 12.3 实测结果 + +| 实现方式 | 相对基线 | 说明 | +|----------|----------|------| +| select 版(IR 内 `select` 选 index) | **慢约 3–6%** | `select` 仍无条件 load `indices[row]`,额外多算 flag load + select,净增指令 | +| branch 版(控制流跳过 `indices[row]` load) | **基本持平** | 多轮差异均在 ±1–2% 噪声内,无可测收益,且增加 IR 复杂度 | + +以 `width16_sum_jit` 三轮交替为例(branch 版):base 6.48 / 6.24 / 6.48ms, +P2 6.20 / 6.31 / 6.25ms——互有高低,落在噪声范围内。 + +### 12.4 结论 + +- **P2 在当前硬件 / 工作负载上没有可测收益,改动已全部回退到基线。** +- 根因:第 11 章把 `indices[row]` 当瓶颈的假设在实测中不成立。flat 输入下 + `indices` 是连续数组的顺序读,**硬件预取使其几乎零成本**,省掉它换不来收益; + select 版反而因多余指令小幅变慢。 +- 按「只做直接必要、不过度工程」的原则,无收益且增加复杂度的改动不保留。 +- 后续若再优化 add_dense 主路径,方向应转向真正的访存瓶颈(如 accumulator 在 + RowContainer 中的非连续布局),而非已被预取覆盖的 `indices[row]` 间接寻址。 +- P3(下沉 per-row accumulator null clear)的待确认正确性约束(新组创建与首次更新 + 是否同 batch)经评估不成立、争议较大,暂缓,不在本轮实施。 From 9dcdbdd9641b3e74adfa09e6a683cac448d7c57a Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 10 Jun 2026 20:45:30 +0800 Subject: [PATCH 30/98] remove hard-coded offset 
--- bolt/exec/GroupingSet.cpp | 8 ++--- bolt/jit/aggregation/HashAggrJit.cpp | 53 +++++++++++++++++++++------- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index f58c18c14..1c042d888 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1132,10 +1132,10 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitInputVectors_[slotIndex] = arg; hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); hashAggrJitDecodedInputs_[slotIndex] = jit::HashAggrJitDecodedInput{ - hashAggrJitDecoded_[slotIndex].dataAsVoid(), - hashAggrJitDecoded_[slotIndex].indices(), - hashAggrJitDecoded_[slotIndex].nulls(&activeRows_), - &hashAggrJitDecoded_[slotIndex]}; + .values = hashAggrJitDecoded_[slotIndex].dataAsVoid(), + .indices = hashAggrJitDecoded_[slotIndex].indices(), + .nulls = hashAggrJitDecoded_[slotIndex].nulls(&activeRows_), + .decodedVector = &hashAggrJitDecoded_[slotIndex]}; fillHashAggrJitRowFieldInputs( hashAggrJitDecodedInputs_[slotIndex], hashAggrJitDecoded_[slotIndex], diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index a9f60116f..66d38c7b1 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,8 @@ extern "C" { +using bytedance::bolt::jit::HashAggrJitDecodedInput; +using bytedance::bolt::jit::HashAggrJitOutput; using bytedance::bolt::jit::JitDecimalAvgState; using bytedance::bolt::jit::JitDecimalSumState; @@ -49,6 +52,32 @@ void logHashAggrJitFunctionIR( << ir; } +constexpr uint64_t kDecodedInputIndicesOffset = + offsetof(HashAggrJitDecodedInput, indices); +constexpr uint64_t kDecodedInputNullsOffset = + offsetof(HashAggrJitDecodedInput, nulls); +constexpr uint64_t kDecodedInputDecodedVectorOffset = + offsetof(HashAggrJitDecodedInput, decodedVector); +constexpr uint64_t kDecodedInputFirstRowFieldOffset = + 
offsetof(HashAggrJitDecodedInput, rowField0Values); +constexpr uint64_t kDecodedInputRowFieldNullsOffsetDelta = + offsetof(HashAggrJitDecodedInput, rowField0Nulls) - + offsetof(HashAggrJitDecodedInput, rowField0Values); +constexpr uint64_t kDecodedInputRowFieldStride = + offsetof(HashAggrJitDecodedInput, rowField1Values) - + offsetof(HashAggrJitDecodedInput, rowField0Values); + +constexpr uint64_t kOutputNullsOffset = offsetof(HashAggrJitOutput, nulls); +constexpr uint64_t kOutputVectorOffset = offsetof(HashAggrJitOutput, vector); +constexpr uint64_t kOutputFirstRowFieldOffset = + offsetof(HashAggrJitOutput, rowField0Values); +constexpr uint64_t kOutputRowFieldNullsOffsetDelta = + offsetof(HashAggrJitOutput, rowField0Nulls) - + offsetof(HashAggrJitOutput, rowField0Values); +constexpr uint64_t kOutputRowFieldStride = + offsetof(HashAggrJitOutput, rowField1Values) - + offsetof(HashAggrJitOutput, rowField0Values); + int64_t jitHashAggrAddWithOverflow( bytedance::bolt::int128_t left, bytedance::bolt::int128_t right, @@ -364,7 +393,7 @@ llvm::Value* loadOutputValues(llvm::IRBuilder<>& builder, llvm::Value* output) { llvm::Value* loadOutputNulls(llvm::IRBuilder<>& builder, llvm::Value* output) { auto* i64Ty = builder.getInt64Ty(); auto* nullsAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), output, static_cast(sizeof(void*))); + builder.getInt8Ty(), output, kOutputNullsOffset); auto* nullsPtrPtr = builder.CreatePointerCast(nullsAddr, i64Ty->getPointerTo()->getPointerTo()); return builder.CreateLoad(i64Ty->getPointerTo(), nullsPtrPtr, "output_nulls"); @@ -373,7 +402,7 @@ llvm::Value* loadOutputNulls(llvm::IRBuilder<>& builder, llvm::Value* output) { llvm::Value* loadOutputVector(llvm::IRBuilder<>& builder, llvm::Value* output) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); auto* vectorAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), output, static_cast(2 * sizeof(void*))); + builder.getInt8Ty(), output, 
kOutputVectorOffset); auto* vectorPtrPtr = builder.CreatePointerCast(vectorAddr, i8PtrTy->getPointerTo()); return builder.CreateLoad(i8PtrTy, vectorPtrPtr, "output_vector"); } @@ -398,7 +427,7 @@ llvm::Value* loadDecodedIndex( auto* indices = loadPointerField( builder, decoded, - sizeof(void*), + kDecodedInputIndicesOffset, i32Ty->getPointerTo(), "decoded_indices"); return builder.CreateLoad(i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); @@ -410,10 +439,10 @@ llvm::Value* loadDecodedRowFieldPointer( int32_t field, bool nulls) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - const auto firstRowFieldOffset = static_cast(4 * sizeof(void*)); auto* pointerType = nulls ? builder.getInt64Ty()->getPointerTo() : i8PtrTy; - auto offset = firstRowFieldOffset + static_cast(field) * 2 * sizeof(void*) + - (nulls ? sizeof(void*) : 0); + auto offset = kDecodedInputFirstRowFieldOffset + + static_cast(field) * kDecodedInputRowFieldStride + + (nulls ? kDecodedInputRowFieldNullsOffsetDelta : 0); return loadPointerField( builder, decoded, @@ -428,10 +457,10 @@ llvm::Value* loadOutputRowFieldPointer( int32_t field, bool nulls) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - const auto firstRowFieldOffset = static_cast(3 * sizeof(void*)); auto* pointerType = nulls ? builder.getInt64Ty()->getPointerTo() : i8PtrTy; - auto offset = firstRowFieldOffset + static_cast(field) * 2 * sizeof(void*) + - (nulls ? sizeof(void*) : 0); + auto offset = kOutputFirstRowFieldOffset + + static_cast(field) * kOutputRowFieldStride + + (nulls ? 
kOutputRowFieldNullsOffsetDelta : 0); return loadPointerField( builder, output, @@ -558,7 +587,7 @@ llvm::Value* loadDecodedValue( auto* values = builder.CreateLoad(i8PtrTy, valuesPtrPtr, "decoded_values"); auto* indicesAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), decoded, static_cast(sizeof(void*))); + builder.getInt8Ty(), decoded, kDecodedInputIndicesOffset); auto* indicesPtrPtr = builder.CreatePointerCast(indicesAddr, i32Ty->getPointerTo()->getPointerTo()); auto* indices = builder.CreateLoad(i32Ty->getPointerTo(), indicesPtrPtr, "decoded_indices"); @@ -593,7 +622,7 @@ llvm::Value* loadDecodedValue( llvm::Value* loadDecodedNulls(llvm::IRBuilder<>& builder, llvm::Value* decoded) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); auto* nullsAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), decoded, static_cast(2 * sizeof(void*))); + builder.getInt8Ty(), decoded, kDecodedInputNullsOffset); auto* nullsPtrPtr = builder.CreatePointerCast(nullsAddr, i8PtrTy->getPointerTo()); return builder.CreateLoad(i8PtrTy, nullsPtrPtr, "decoded_nulls"); } @@ -618,7 +647,7 @@ llvm::Value* isDecodedNull( llvm::Value* loadDecodedVector(llvm::IRBuilder<>& builder, llvm::Value* decoded) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); auto* decodedVectorAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), decoded, static_cast(3 * sizeof(void*))); + builder.getInt8Ty(), decoded, kDecodedInputDecodedVectorOffset); auto* decodedVectorPtrPtr = builder.CreatePointerCast(decodedVectorAddr, i8PtrTy->getPointerTo()); return builder.CreateLoad(i8PtrTy, decodedVectorPtrPtr, "decoded_vector"); From 266b38c76f132626cd47d95801c40b46d88533e0 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 10 Jun 2026 21:21:22 +0800 Subject: [PATCH 31/98] remove update numNulls_ in jit execution way --- bolt/exec/Aggregate.cpp | 1 - bolt/exec/Aggregate.h | 8 ----- bolt/exec/GroupingSet.cpp | 10 ------ 
.../prestosql/aggregates/CountAggregate.cpp | 1 - .../prestosql/aggregates/MinMaxAggregates.cpp | 1 - .../sparksql/aggregates/AverageAggregate.cpp | 3 -- .../sparksql/aggregates/DecimalSumAggregate.h | 1 - .../sparksql/aggregates/SumAggregate.cpp | 1 - bolt/jit/aggregation/HashAggrJit.h | 5 --- doc/hash-aggr-jit-todolist.md | 34 +++++++++++++++++++ 10 files changed, 34 insertions(+), 31 deletions(-) create mode 100644 doc/hash-aggr-jit-todolist.md diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index 1aebd48ec..1a08b6fc7 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -346,7 +346,6 @@ jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( .countStar = descriptor.countStar, .mergeInput = descriptor.mergeInput, .decimal = descriptor.decimal, - .initSetsNull = descriptor.initSetsNull, .precision = descriptor.precision, .scale = descriptor.scale, .auxPrecision = descriptor.auxPrecision, diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index 34b5a1903..301709f4c 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -125,14 +125,6 @@ class Aggregate { jit::HashAggrJitSlot createHashAggrJitSlot( int32_t aggregateIndex, const jit::HashAggrJitDescriptor& descriptor) const; - - // HashAggr JIT initGroup marks accumulators as null by writing the null bit - // directly, bypassing setAllNulls/setNull. Since non-JIT extract relies on - // numNulls_ (see isNull()), GroupingSet must keep it in sync after running - // the JIT init path for the corresponding number of new groups. 
- void addNumNulls(uint64_t count) { - numNulls_ += count; - } #endif void setAllocator(HashStringAllocator* allocator) { diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 1c042d888..3da5bbc06 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1159,16 +1159,6 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitNewGroups_[i] = groups[newGroups[i]]; } chunk.init(hashAggrJitNewGroups_.data(), newGroups.size()); - // JIT initGroup writes the null bit directly without touching - // Aggregate::numNulls_. Non-JIT extract relies on numNulls_ (isNull() - // short-circuits when it is 0), so keep it in sync here, mirroring the - // non-JIT initializeNewGroups/setAllNulls path. - for (const auto& slot : chunk.slots()) { - if (slot.initSetsNull) { - aggregates_[slot.aggregateIndex].function->addNumNulls( - newGroups.size()); - } - } VLOG(1) << "HashAggrJit initialized new groups for chunk " << chunk.functionName() << " newGroups=" << newGroups.size(); } diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 2b98db0d2..9d43c04a8 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -87,7 +87,6 @@ class CountAggregate : public SimpleNumericAggregate { context.isCountStar(), !context.isRawInput, false, - /*initSetsNull=*/false, /*precision=*/0, /*scale=*/0, /*auxPrecision=*/0, diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 20c34620c..7a2e691ef 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -78,7 +78,6 @@ class MinMaxAggregate : public SimpleNumericAggregate { false, !context.isRawInput, false, - /*initSetsNull=*/true, /*precision=*/0, /*scale=*/0, /*auxPrecision=*/0, diff --git 
a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 78aea72d5..db4ae7ca2 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -130,7 +130,6 @@ class AverageAggregate false, true, false, - /*initSetsNull=*/true, /*precision=*/0, /*scale=*/0, /*auxPrecision=*/0, @@ -149,7 +148,6 @@ class AverageAggregate false, false, false, - /*initSetsNull=*/true, /*precision=*/0, /*scale=*/0, /*auxPrecision=*/0, @@ -370,7 +368,6 @@ class DecimalAverageAggregate : public DecimalAggregate { false, !context.isRawInput, true, - /*initSetsNull=*/true, /*precision=*/sumPrecision, /*scale=*/sumScale, /*auxPrecision=*/resultPrecision, diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index ffdf16e25..83ab62556 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -94,7 +94,6 @@ class DecimalSumAggregate : public exec::Aggregate { false, !context.isRawInput, true, - /*initSetsNull=*/true, /*precision=*/resultPrecision, /*scale=*/resultScale, /*auxPrecision=*/0, diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index 44f65ca8d..3645c5bd3 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -173,7 +173,6 @@ class SumAggregate : public SumAggregateBase { false, !context.isRawInput, false, - /*initSetsNull=*/true, /*precision=*/0, /*scale=*/0, /*auxPrecision=*/0, diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index f7bf617a3..5112b4ba5 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -118,10 +118,6 @@ struct HashAggrJitDescriptor { bool countStar{false}; bool 
mergeInput{false}; bool decimal{false}; - // Whether initGroup marks the accumulator as null for each new group. When - // true, GroupingSet must keep Aggregate::numNulls_ in sync (non-JIT extract - // relies on it), mirroring the non-JIT initializeNewGroups path. - bool initSetsNull{false}; // Result decimal precision/scale, used by decimal extract overflow checks. // Only meaningful when decimal == true. int32_t precision{0}; @@ -172,7 +168,6 @@ struct HashAggrJitSlot { bool countStar{false}; bool mergeInput{false}; bool decimal{false}; - bool initSetsNull{false}; int32_t precision{0}; int32_t scale{0}; int32_t auxPrecision{0}; diff --git a/doc/hash-aggr-jit-todolist.md b/doc/hash-aggr-jit-todolist.md new file mode 100644 index 000000000..15f58c742 --- /dev/null +++ b/doc/hash-aggr-jit-todolist.md @@ -0,0 +1,34 @@ +# HashAggr JIT TODO List + +## Pending + +### [P2] chunk 同时 codegen `add_dense` 和 `add_dense_no_null`,编译时间与产物 ×2 + +**现状** +- 每个 chunk 在 `compile()` 里生成两份 add 函数,仅 `checkInputNulls` 不同: + - `bolt/jit/aggregation/HashAggrJit.cpp:1281-1282` +- 两者差异 100% 在 `genAddDenseIR` 内的 null-check 分支: + - `bolt/jit/aggregation/HashAggrJit.cpp:1016-1029`、`bolt/jit/aggregation/HashAggrJit.cpp:1040` +- 运行时按 batch 级 `inputsMayHaveNulls` 选函数指针,batch 内 stable。 + +**评估结论** +- 问题真实:codegen 时间 ~×2。 +- 但**非 P0**:编译是 per-chunk 一次性、结果缓存在 `module_`/`addDense_`/`addDenseNoNull_` + (`bolt/jit/aggregation/HashAggrJit.cpp:1301-1304`),运行热路径只调用其中一个函数, + 不存在运行期代码膨胀。影响的是编译延迟,不是执行性能。建议定级 **P2**。 + +**为什么 pending** +- 是否值得改,取决于生产实际 workload,目前未知。 + +**决策需要的数据** +- JIT 编译耗时占比 / chunk 编译次数。 +- `inputsMayHaveNulls == false` 的 batch 实际占比。 + +**候选方案** +- 维持现状:若编译耗时占比可忽略,不改。 +- 推荐(建议2,lazy):默认只编 `add_dense`,仅当出现 `inputsMayHaveNulls == false` + 的 batch 时再 lazy 编 `add_dense_no_null`;未就绪前 fallback 到 `add_dense` + (对 no-null 输入同样正确,仅损失少量性能)→ 砍掉常见场景一半编译量,零正确性风险。 +- 不推荐(建议1,运行期 i1 参数):会让 no-null 热路径丢失编译期 dead-branch 消除,反而变慢。 +- 高成本(建议3,alwaysinline + wrapper):理论最优但需重写 add codegen 结构, + 
回归面大,仅为省一次性编译,性价比低。 From 3a733762dea5397a1440ef7042a0a0574703c7f6 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 10 Jun 2026 21:45:51 +0800 Subject: [PATCH 32/98] reuse HashAggrJitDescriptor struct in HashAggrJitSlot --- bolt/exec/Aggregate.cpp | 12 +-- bolt/exec/GroupingSet.cpp | 24 ++--- .../prestosql/aggregates/CountAggregate.cpp | 26 ++--- .../prestosql/aggregates/MinMaxAggregates.cpp | 46 ++++----- .../sparksql/aggregates/AverageAggregate.cpp | 95 ++++++++++--------- .../sparksql/aggregates/DecimalSumAggregate.h | 39 ++++---- .../sparksql/aggregates/SumAggregate.cpp | 40 ++++---- bolt/jit/aggregation/HashAggrJit.cpp | 52 +++++----- bolt/jit/aggregation/HashAggrJit.h | 14 +-- doc/hash-aggr-jit-todolist.md | 50 ++++++++++ 10 files changed, 216 insertions(+), 182 deletions(-) diff --git a/bolt/exec/Aggregate.cpp b/bolt/exec/Aggregate.cpp index 1a08b6fc7..824f15077 100644 --- a/bolt/exec/Aggregate.cpp +++ b/bolt/exec/Aggregate.cpp @@ -337,20 +337,10 @@ jit::HashAggrJitSlot Aggregate::createHashAggrJitSlot( const jit::HashAggrJitDescriptor& descriptor) const { return jit::HashAggrJitSlot{ .aggregateIndex = aggregateIndex, - .kind = descriptor.kind, - .inputKind = descriptor.inputKind, - .accumulatorKind = descriptor.accumulatorKind, .offset = accumulatorOffset(), .nullByte = accumulatorNullByte(), .nullMask = accumulatorNullMask(), - .countStar = descriptor.countStar, - .mergeInput = descriptor.mergeInput, - .decimal = descriptor.decimal, - .precision = descriptor.precision, - .scale = descriptor.scale, - .auxPrecision = descriptor.auxPrecision, - .auxScale = descriptor.auxScale, - .ops = descriptor.ops}; + .desc = descriptor}; } #endif diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 3da5bbc06..557abb906 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -120,7 +120,7 @@ void fillHashAggrJitRowFieldInputs( jit::HashAggrJitDecodedInput& input, const DecodedVector& decoded, const jit::HashAggrJitSlot& slot) { - 
if (!slot.mergeInput || slot.kind != jit::HashAggrJitKind::Avg) { + if (!slot.desc.mergeInput || slot.desc.kind != jit::HashAggrJitKind::Avg) { return; } const auto* base = decoded.base(); @@ -138,7 +138,7 @@ void fillHashAggrJitRowFieldInputs( return; } input.rowField0Values = - hashAggrJitRawInputValues(sumVector.get(), slot.inputKind); + hashAggrJitRawInputValues(sumVector.get(), slot.desc.inputKind); input.rowField0Nulls = sumVector->rawNulls(); input.rowField1Values = hashAggrJitRawInputValues(countVector.get(), jit::HashAggrJitValueKind::Int64); @@ -182,14 +182,14 @@ std::string hashAggrJitSlotDebugString( } out << "]"; } - out << " kind=" << static_cast(slot.kind) - << " inputKind=" << jit::hashAggrJitValueKindName(slot.inputKind) - << " accKind=" << jit::hashAggrJitValueKindName(slot.accumulatorKind) + out << " kind=" << static_cast(slot.desc.kind) + << " inputKind=" << jit::hashAggrJitValueKindName(slot.desc.inputKind) + << " accKind=" << jit::hashAggrJitValueKindName(slot.desc.accumulatorKind) << " offset=" << slot.offset << " nullByte=" << slot.nullByte << " nullMask=" << static_cast(slot.nullMask) - << " countStar=" << slot.countStar - << " mergeInput=" << slot.mergeInput << " decimal=" << slot.decimal - << " ops=" << (slot.ops != nullptr ? slot.ops->id : "null"); + << " countStar=" << slot.desc.countStar + << " mergeInput=" << slot.desc.mergeInput << " decimal=" << slot.desc.decimal + << " ops=" << (slot.desc.ops != nullptr ? 
slot.desc.ops->id : "null"); return out.str(); } @@ -1109,7 +1109,7 @@ void GroupingSet::runHashAggrJitChunks( skipReason = "selectivity vector is not dense activeRows or has no selections"; break; } - if (slot.countStar) { + if (slot.desc.countStar) { continue; } if (aggregate.inputs.size() != 1) { @@ -1217,7 +1217,7 @@ void GroupingSet::runHashAggrJitExtractChunks( } auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); const auto expectedEncoding = - (isPartial_ && slot.kind == jit::HashAggrJitKind::Avg) + (isPartial_ && slot.desc.kind == jit::HashAggrJitKind::Avg) ? VectorEncoding::Simple::ROW : VectorEncoding::Simple::FLAT; if (aggregateVector->encoding() != expectedEncoding) { @@ -1232,10 +1232,10 @@ void GroupingSet::runHashAggrJitExtractChunks( hashAggrJitOutputs_[slotIndex].vector = aggregateVector.get(); if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { hashAggrJitOutputs_[slotIndex].values = - hashAggrJitRawOutputValues(aggregateVector.get(), slot.accumulatorKind); + hashAggrJitRawOutputValues(aggregateVector.get(), slot.desc.accumulatorKind); hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); } else if (aggregateVector->encoding() == VectorEncoding::Simple::ROW && - slot.kind == jit::HashAggrJitKind::Avg) { + slot.desc.kind == jit::HashAggrJitKind::Avg) { hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); fillHashAggrJitPartialAvgOutput( hashAggrJitOutputs_[slotIndex], aggregateVector.get()); diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 9d43c04a8..c199be1f4 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -81,17 +81,17 @@ class CountAggregate : public SimpleNumericAggregate { inputKind = *maybeInputKind; } return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Count, - inputKind, - 
jit::HashAggrJitValueKind::Int64, - context.isCountStar(), - !context.isRawInput, - false, - /*precision=*/0, - /*scale=*/0, - /*auxPrecision=*/0, - /*auxScale=*/0, - hashAggrJitOps()}; + .kind = jit::HashAggrJitKind::Count, + .inputKind = inputKind, + .accumulatorKind = jit::HashAggrJitValueKind::Int64, + .countStar = context.isCountStar(), + .mergeInput = !context.isRawInput, + .decimal = false, + .precision = 0, + .scale = 0, + .auxPrecision = 0, + .auxScale = 0, + .ops = hashAggrJitOps()}; } private: @@ -139,11 +139,11 @@ class CountAggregate : public SimpleNumericAggregate { const jit::HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { - llvm::Value* inc = slot.countStar + llvm::Value* inc = slot.desc.countStar ? codegen.builder().getInt64(1) : codegen.castValue( codegen.loadDecodedValue(decoded, row, slot), - slot.inputKind, + slot.desc.inputKind, jit::HashAggrJitValueKind::Int64); addInc(codegen, group, slot, inc); } diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 7a2e691ef..e8f2ced2a 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -72,17 +72,17 @@ class MinMaxAggregate : public SimpleNumericAggregate { return std::nullopt; } return jit::HashAggrJitDescriptor{ - jitKind(), - *inputKind, - *inputKind, - false, - !context.isRawInput, - false, - /*precision=*/0, - /*scale=*/0, - /*auxPrecision=*/0, - /*auxScale=*/0, - hashAggrJitOps()}; + .kind = jitKind(), + .inputKind = *inputKind, + .accumulatorKind = *inputKind, + .countStar = false, + .mergeInput = !context.isRawInput, + .decimal = false, + .precision = 0, + .scale = 0, + .auxPrecision = 0, + .auxScale = 0, + .ops = hashAggrJitOps()}; } private: @@ -91,8 +91,8 @@ class MinMaxAggregate : public SimpleNumericAggregate { llvm::Value* group, const jit::HashAggrJitSlot& slot) { codegen.setAccumulatorNull(group, slot); - auto* type = 
codegen.llvmType(slot.accumulatorKind); - if (codegen.isFloatKind(slot.accumulatorKind)) { + auto* type = codegen.llvmType(slot.desc.accumulatorKind); + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { codegen.storeValue(group, type, slot.offset, llvm::ConstantFP::get(type, 0.0)); } else { codegen.storeValue(group, type, slot.offset, llvm::ConstantInt::get(type, 0)); @@ -111,16 +111,16 @@ class MinMaxAggregate : public SimpleNumericAggregate { llvm::BasicBlock*) { auto* value = codegen.castValue( codegen.loadDecodedValue(decoded, row, slot), - slot.inputKind, - slot.accumulatorKind); - auto* type = codegen.llvmType(slot.accumulatorKind); + slot.desc.inputKind, + slot.desc.accumulatorKind); + auto* type = codegen.llvmType(slot.desc.accumulatorKind); auto* oldValue = codegen.loadValue(group, type, slot.offset); auto* nullState = codegen.isAccumulatorNull(group, slot); llvm::Value* better = nullptr; - if (codegen.isFloatKind(slot.accumulatorKind)) { + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { auto* oldIsNan = codegen.builder().CreateFCmpUNO(oldValue, oldValue); auto* valueIsNan = codegen.builder().CreateFCmpUNO(value, value); - if (slot.kind == jit::HashAggrJitKind::Min) { + if (slot.desc.kind == jit::HashAggrJitKind::Min) { better = codegen.builder().CreateOr( codegen.builder().CreateAnd(oldIsNan, codegen.builder().CreateNot(valueIsNan)), codegen.builder().CreateAnd( @@ -133,7 +133,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { valueIsNan, codegen.builder().CreateFCmpOLT(oldValue, value))); } } else { - better = slot.kind == jit::HashAggrJitKind::Min + better = slot.desc.kind == jit::HashAggrJitKind::Min ? codegen.builder().CreateICmpSLT(value, oldValue) : codegen.builder().CreateICmpSGT(value, oldValue); } @@ -175,8 +175,8 @@ class MinMaxAggregate : public SimpleNumericAggregate { bool) { // Flat setters exist for i8/i16/i32/i64/f32/f64 only. Int128 (long decimal) // and Bool have no flat setter yet, fall back to non-JIT extract. 
- return slot.accumulatorKind != jit::HashAggrJitValueKind::Int128 && - slot.accumulatorKind != jit::HashAggrJitValueKind::Bool; + return slot.desc.accumulatorKind != jit::HashAggrJitValueKind::Int128 && + slot.desc.accumulatorKind != jit::HashAggrJitValueKind::Bool; } static void compileHashAggrJitExtract( @@ -184,11 +184,11 @@ class MinMaxAggregate : public SimpleNumericAggregate { llvm::Value* group, const jit::HashAggrJitSlot& slot, const jit::HashAggrJitExtractTarget& target) { - auto* value = codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + auto* value = codegen.loadValue(group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); auto* isNull = codegen.builder().CreateZExt( codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); codegen.emitFlatValue( - target.resultVector, target.row, slot.accumulatorKind, value, isNull); + target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); } static const jit::HashAggrJitOps* hashAggrJitOps() { diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index db4ae7ca2..5a8f53b81 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -124,17 +124,17 @@ class AverageAggregate if (!context.isRawInput) { return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Avg, - jit::HashAggrJitValueKind::Double, - jit::HashAggrJitValueKind::Double, - false, - true, - false, - /*precision=*/0, - /*scale=*/0, - /*auxPrecision=*/0, - /*auxScale=*/0, - hashAggrJitOps()}; + .kind = jit::HashAggrJitKind::Avg, + .inputKind = jit::HashAggrJitValueKind::Double, + .accumulatorKind = jit::HashAggrJitValueKind::Double, + .countStar = false, + .mergeInput = true, + .decimal = false, + .precision = 0, + .scale = 0, + .auxPrecision = 0, + .auxScale = 0, + .ops = hashAggrJitOps()}; } auto inputKind = 
jit::hashAggrJitValueKind(context.inputType->kind()); @@ -142,17 +142,17 @@ class AverageAggregate return std::nullopt; } return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Avg, - *inputKind, - jit::HashAggrJitValueKind::Double, - false, - false, - false, - /*precision=*/0, - /*scale=*/0, - /*auxPrecision=*/0, - /*auxScale=*/0, - hashAggrJitOps()}; + .kind = jit::HashAggrJitKind::Avg, + .inputKind = *inputKind, + .accumulatorKind = jit::HashAggrJitValueKind::Double, + .countStar = false, + .mergeInput = false, + .decimal = false, + .precision = 0, + .scale = 0, + .auxPrecision = 0, + .auxScale = 0, + .ops = hashAggrJitOps()}; } private: @@ -163,9 +163,9 @@ class AverageAggregate codegen.setAccumulatorNull(group, slot); codegen.storeValue( group, - codegen.llvmType(slot.accumulatorKind), + codegen.llvmType(slot.desc.accumulatorKind), slot.offset, - llvm::ConstantFP::get(codegen.llvmType(slot.accumulatorKind), 0.0)); + llvm::ConstantFP::get(codegen.llvmType(slot.desc.accumulatorKind), 0.0)); codegen.storeValue( group, codegen.builder().getInt64Ty(), @@ -183,13 +183,13 @@ class AverageAggregate llvm::BasicBlock*) { auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); auto* value = - codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); + codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); codegen.clearAccumulatorNull(group, slot); auto* oldSum = - codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + codegen.loadValue(group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); codegen.storeValue( group, - codegen.llvmType(slot.accumulatorKind), + codegen.llvmType(slot.desc.accumulatorKind), slot.offset, codegen.builder().CreateFAdd(oldSum, value)); auto* oldCount = codegen.loadValue( @@ -234,7 +234,7 @@ class AverageAggregate const jit::HashAggrJitSlot& slot, bool) { // Only double avg (sum=double@offset, count=int64@offset+8) is supported. 
- return slot.accumulatorKind == jit::HashAggrJitValueKind::Double; + return slot.desc.accumulatorKind == jit::HashAggrJitValueKind::Double; } static void compileHashAggrJitExtract( @@ -361,18 +361,19 @@ class DecimalAverageAggregate : public DecimalAggregate { const auto [resultPrecision, resultScale] = getDecimalPrecisionScale(*this->resultType().get()); return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Avg, - valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 - : jit::HashAggrJitValueKind::Int128, - jit::HashAggrJitValueKind::Int128, - false, - !context.isRawInput, - true, - /*precision=*/sumPrecision, - /*scale=*/sumScale, - /*auxPrecision=*/resultPrecision, - /*auxScale=*/resultScale, - hashAggrJitOps()}; + .kind = jit::HashAggrJitKind::Avg, + .inputKind = valueType->isShortDecimal() + ? jit::HashAggrJitValueKind::Int64 + : jit::HashAggrJitValueKind::Int128, + .accumulatorKind = jit::HashAggrJitValueKind::Int128, + .countStar = false, + .mergeInput = !context.isRawInput, + .decimal = true, + .precision = sumPrecision, + .scale = sumScale, + .auxPrecision = resultPrecision, + .auxScale = resultScale, + .ops = hashAggrJitOps()}; } #endif @@ -637,16 +638,16 @@ class DecimalAverageAggregate : public DecimalAggregate { llvm::BasicBlock*) { auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? "jit_HashAggrUpdateDecimalAvgI128" : "jit_HashAggrUpdateDecimalAvgI64"; codegen.builder().CreateCall( codegen.module().getFunction(helper), {group, codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 + slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? 
codegen.castValue( - rawValue, slot.inputKind, jit::HashAggrJitValueKind::Int128) + rawValue, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) : rawValue}); } @@ -691,18 +692,18 @@ class DecimalAverageAggregate : public DecimalAggregate { codegen.builder().CreateBr(continueBlock); codegen.builder().SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); + auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? "jit_HashAggrMergeDecimalAvgI128" : "jit_HashAggrMergeDecimalAvgI64"; codegen.builder().CreateCall( codegen.module().getFunction(helper), {group, codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 + slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? codegen.castValue( - sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) + sum, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) : sum, count}); codegen.builder().CreateBr(continueBlock); diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 83ab62556..e4fc8b2c1 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -87,18 +87,19 @@ class DecimalSumAggregate : public exec::Aggregate { const auto [resultPrecision, resultScale] = getDecimalPrecisionScale(*sumType_.get()); return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Sum, - valueType->isShortDecimal() ? 
jit::HashAggrJitValueKind::Int64 - : jit::HashAggrJitValueKind::Int128, - jit::HashAggrJitValueKind::Int128, - false, - !context.isRawInput, - true, - /*precision=*/resultPrecision, - /*scale=*/resultScale, - /*auxPrecision=*/0, - /*auxScale=*/0, - hashAggrJitOps()}; + .kind = jit::HashAggrJitKind::Sum, + .inputKind = valueType->isShortDecimal() + ? jit::HashAggrJitValueKind::Int64 + : jit::HashAggrJitValueKind::Int128, + .accumulatorKind = jit::HashAggrJitValueKind::Int128, + .countStar = false, + .mergeInput = !context.isRawInput, + .decimal = true, + .precision = resultPrecision, + .scale = resultScale, + .auxPrecision = 0, + .auxScale = 0, + .ops = hashAggrJitOps()}; } private: @@ -122,16 +123,16 @@ class DecimalSumAggregate : public exec::Aggregate { llvm::BasicBlock*) { auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? "jit_HashAggrUpdateDecimalSumI128" : "jit_HashAggrUpdateDecimalSumI64"; codegen.builder().CreateCall( codegen.module().getFunction(helper), {group, codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 + slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? 
codegen.castValue( - rawValue, slot.inputKind, jit::HashAggrJitValueKind::Int128) + rawValue, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) : rawValue}); } @@ -172,18 +173,18 @@ class DecimalSumAggregate : public exec::Aggregate { codegen.builder().CreateBr(continueBlock); codegen.builder().SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.inputKind); + auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.inputKind == jit::HashAggrJitValueKind::Int128 + const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? "jit_HashAggrMergeDecimalSumI128" : "jit_HashAggrMergeDecimalSumI64"; codegen.builder().CreateCall( codegen.module().getFunction(helper), {group, codegen.builder().getInt32(slot.offset), - slot.inputKind == jit::HashAggrJitValueKind::Int128 + slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 ? codegen.castValue( - sum, slot.inputKind, jit::HashAggrJitValueKind::Int128) + sum, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) : sum, isEmpty}); codegen.builder().CreateBr(continueBlock); diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index 3645c5bd3..a0d9b7764 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -167,17 +167,17 @@ class SumAggregate : public SumAggregateBase { : jit::HashAggrJitValueKind::Int64; return jit::HashAggrJitDescriptor{ - jit::HashAggrJitKind::Sum, - *inputKind, - accumulatorKind, - false, - !context.isRawInput, - false, - /*precision=*/0, - /*scale=*/0, - /*auxPrecision=*/0, - /*auxScale=*/0, - hashAggrJitOps()}; + .kind = jit::HashAggrJitKind::Sum, + .inputKind = *inputKind, + .accumulatorKind = accumulatorKind, + .countStar = false, + .mergeInput = !context.isRawInput, + .decimal = false, + .precision = 0, + 
.scale = 0, + .auxPrecision = 0, + .auxScale = 0, + .ops = hashAggrJitOps()}; } private: @@ -186,8 +186,8 @@ class SumAggregate : public SumAggregateBase { llvm::Value* group, const jit::HashAggrJitSlot& slot) { codegen.setAccumulatorNull(group, slot); - auto* accType = codegen.llvmType(slot.accumulatorKind); - if (codegen.isFloatKind(slot.accumulatorKind)) { + auto* accType = codegen.llvmType(slot.desc.accumulatorKind); + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { codegen.storeValue( group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); } else { @@ -208,11 +208,11 @@ class SumAggregate : public SumAggregateBase { llvm::BasicBlock*) { auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); auto* value = - codegen.castValue(rawValue, slot.inputKind, slot.accumulatorKind); - auto* accType = codegen.llvmType(slot.accumulatorKind); + codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); + auto* accType = codegen.llvmType(slot.desc.accumulatorKind); codegen.clearAccumulatorNull(group, slot); auto* oldValue = codegen.loadValue(group, accType, slot.offset); - auto* newValue = codegen.isFloatKind(slot.accumulatorKind) + auto* newValue = codegen.isFloatKind(slot.desc.accumulatorKind) ? codegen.builder().CreateFAdd(oldValue, value) : codegen.builder().CreateAdd(oldValue, value); codegen.storeValue(group, accType, slot.offset, newValue); @@ -246,8 +246,8 @@ class SumAggregate : public SumAggregateBase { const jit::HashAggrJitSlot& slot, bool) { // spark sum intermediate type == result type (bigint=bigint / double=double). 
- return slot.accumulatorKind == jit::HashAggrJitValueKind::Int64 || - slot.accumulatorKind == jit::HashAggrJitValueKind::Double; + return slot.desc.accumulatorKind == jit::HashAggrJitValueKind::Int64 || + slot.desc.accumulatorKind == jit::HashAggrJitValueKind::Double; } static void compileHashAggrJitExtract( @@ -256,11 +256,11 @@ class SumAggregate : public SumAggregateBase { const jit::HashAggrJitSlot& slot, const jit::HashAggrJitExtractTarget& target) { auto* value = - codegen.loadValue(group, codegen.llvmType(slot.accumulatorKind), slot.offset); + codegen.loadValue(group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); auto* isNull = codegen.builder().CreateZExt( codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); codegen.emitFlatValue( - target.resultVector, target.row, slot.accumulatorKind, value, isNull); + target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); } static const jit::HashAggrJitOps* hashAggrJitOps() { diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 66d38c7b1..9fd17fd72 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -594,7 +594,7 @@ llvm::Value* loadDecodedValue( auto* index = builder.CreateLoad( i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); - if (slot.inputKind == HashAggrJitValueKind::Bool) { + if (slot.desc.inputKind == HashAggrJitValueKind::Bool) { auto* wordTy = builder.getInt64Ty(); auto* wordIndex = builder.CreateLShr(index, builder.getInt32(6)); auto* bitIndex = builder.CreateAnd(index, builder.getInt32(63)); @@ -610,7 +610,7 @@ llvm::Value* loadDecodedValue( builder.getInt8Ty()); } - auto* type = llvmType(builder, slot.inputKind); + auto* type = llvmType(builder, slot.desc.inputKind); auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); auto* valueAddr = builder.CreateInBoundsGEP( type, typedValues, builder.CreateZExt(index, builder.getInt64Ty())); @@ -878,7 
+878,7 @@ void HashAggrJitCodegen::emitDecimalSumExtract( const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalSum" : "jit_HashAggrExtractFinalDecimalSum"; auto* longDecimal = builder().getInt8( - slot.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), @@ -886,8 +886,8 @@ void HashAggrJitCodegen::emitDecimalSumExtract( row, group, builder().getInt32(slot.offset), - builder().getInt32(slot.precision), - builder().getInt32(slot.scale), + builder().getInt32(slot.desc.precision), + builder().getInt32(slot.desc.scale), longDecimal}); } @@ -900,7 +900,7 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalAvg" : "jit_HashAggrExtractFinalDecimalAvg"; auto* longDecimal = builder().getInt8( - slot.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 
1 : 0); auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), @@ -908,8 +908,8 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( row, group, builder().getInt32(slot.offset), - builder().getInt32(slot.precision), - builder().getInt32(slot.scale), + builder().getInt32(slot.desc.precision), + builder().getInt32(slot.desc.scale), longDecimal}); } @@ -956,10 +956,10 @@ bool genInitIR( auto* group = builder.CreateLoad(i8PtrTy, groupAddr); for (const auto& slot : slots) { - if (slot.ops == nullptr || slot.ops->initGroup == nullptr) { + if (slot.desc.ops == nullptr || slot.desc.ops->initGroup == nullptr) { return true; } - slot.ops->initGroup(codegen, group, slot); + slot.desc.ops->initGroup(codegen, group, slot); } auto* next = builder.CreateAdd(index, builder.getInt32(1)); @@ -1013,7 +1013,7 @@ bool genAddDenseIR( auto* nextBlock = llvm::BasicBlock::Create(context, "slot_next", func, end); auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); - if (checkInputNulls && !slot.countStar) { + if (checkInputNulls && !slot.desc.countStar) { auto* nulls = codegen.loadDecodedNulls(decoded); auto* nullCheckBlock = llvm::BasicBlock::Create(context, "slot_null_check", func, end); @@ -1029,11 +1029,11 @@ bool genAddDenseIR( } builder.SetInsertPoint(updateBlock); - if (slot.ops == nullptr) { + if (slot.desc.ops == nullptr) { return true; } auto* addFn = - slot.mergeInput ? slot.ops->addIntermediateResults : slot.ops->addRawInput; + slot.desc.mergeInput ? 
slot.desc.ops->addIntermediateResults : slot.desc.ops->addRawInput; if (addFn == nullptr) { return true; } @@ -1103,8 +1103,8 @@ bool genExtractIR( auto* end = llvm::BasicBlock::Create(context, "end", func); builder.SetInsertPoint(entry); for (auto i = 0; i < slots.size(); ++i) { - if (slots[i].ops == nullptr || slots[i].ops->canExtract == nullptr || - !slots[i].ops->canExtract(slots[i], partialOutput)) { + if (slots[i].desc.ops == nullptr || slots[i].desc.ops->canExtract == nullptr || + !slots[i].desc.ops->canExtract(slots[i], partialOutput)) { continue; } auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); @@ -1121,16 +1121,16 @@ bool genExtractIR( for (auto i = 0; i < slots.size(); ++i) { const auto& slot = slots[i]; - if (slot.ops == nullptr || slot.ops->canExtract == nullptr || - !slot.ops->canExtract(slot, partialOutput)) { + if (slot.desc.ops == nullptr || slot.desc.ops->canExtract == nullptr || + !slot.desc.ops->canExtract(slot, partialOutput)) { continue; } auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); - if (slot.ops->extract == nullptr) { + if (slot.desc.ops->extract == nullptr) { return true; } - slot.ops->extract( + slot.desc.ops->extract( codegen, group, slot, HashAggrJitExtractTarget{vector, row, partialOutput}); } @@ -1227,12 +1227,12 @@ std::string HashAggrJitChunk::functionName() const { out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" << slots_.size(); for (const auto& slot : slots_) { - out << "_" << (slot.ops != nullptr ? slot.ops->id : "unknown") << "_" - << static_cast(slot.kind) << hashAggrJitValueKindName(slot.inputKind) - << hashAggrJitValueKindName(slot.accumulatorKind) << "o" << slot.offset + out << "_" << (slot.desc.ops != nullptr ? 
slot.desc.ops->id : "unknown") << "_" + << static_cast(slot.desc.kind) << hashAggrJitValueKindName(slot.desc.inputKind) + << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" << slot.offset << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) - << (slot.countStar ? "s" : "x") << (slot.mergeInput ? "g" : "r") - << (slot.decimal ? "d" : "n"); + << (slot.desc.countStar ? "s" : "x") << (slot.desc.mergeInput ? "g" : "r") + << (slot.desc.decimal ? "d" : "n"); } return out.str(); } @@ -1242,8 +1242,8 @@ bool HashAggrJitChunk::canExtract() const { return false; } for (const auto& slot : slots_) { - if (slot.ops == nullptr || slot.ops->canExtract == nullptr || - !slot.ops->canExtract(slot, partialOutput_)) { + if (slot.desc.ops == nullptr || slot.desc.ops->canExtract == nullptr || + !slot.desc.ops->canExtract(slot, partialOutput_)) { return false; } } diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 5112b4ba5..e757ceccb 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -159,20 +159,12 @@ struct HashAggrJitOps { struct HashAggrJitSlot { int32_t aggregateIndex; - HashAggrJitKind kind; - HashAggrJitValueKind inputKind; - HashAggrJitValueKind accumulatorKind; int32_t offset; int32_t nullByte; uint8_t nullMask; - bool countStar{false}; - bool mergeInput{false}; - bool decimal{false}; - int32_t precision{0}; - int32_t scale{0}; - int32_t auxPrecision{0}; - int32_t auxScale{0}; - const HashAggrJitOps* ops{nullptr}; + // All aggregate-level traits live in the descriptor; IR-side code reads them + // through 'desc'. Only the row-layout fields above are slot-specific. 
+ HashAggrJitDescriptor desc; }; struct HashAggrJitExtractTarget { diff --git a/doc/hash-aggr-jit-todolist.md b/doc/hash-aggr-jit-todolist.md index 15f58c742..63524e95e 100644 --- a/doc/hash-aggr-jit-todolist.md +++ b/doc/hash-aggr-jit-todolist.md @@ -1,5 +1,55 @@ # HashAggr JIT TODO List +## Resolved(已处理,保留遗留风险备忘) + +### Descriptor ↔ Slot 字段重复 / positional init 易错 + +**问题** +- `HashAggrJitDescriptor` 与 `HashAggrJitSlot` 字段大量重复(slot 仅多 row-layout 字段), + 各 aggregate 以 positional 方式构造 descriptor(连续 bool 靠人眼对位,易出低级 bug), + `createHashAggrJitSlot` 又逐字段 copy boilerplate。 + +**本次处理** +- 建议1:6 处 `createHashAggrJitDescriptor` 全改 C++20 designated initializer + (`.kind=`、`.inputKind=` …),消除 positional 对位风险;未 reorder 字段。 +- 建议2:`HashAggrJitSlot` 改为内嵌 `HashAggrJitDescriptor desc`,只保留 + `aggregateIndex/offset/nullByte/nullMask` + `desc`;`createHashAggrJitSlot` + 缩为 4 字段 + `.desc = descriptor`。IR 端与 `functionName()` 等约 70 处 + `slot.` 统一改成 `slot.desc.`;`offset/nullByte/nullMask` 仍在顶层。 +- 验证:无残留旧式访问、无 `descriptor.desc`、无 `.desc.offset` 误改、无双重 `.desc.desc`、 + `HashAggrJitDescriptor::signature()`(裸字段名)未受影响。未重新编译。 + +### 删除 JIT init 对 `Aggregate::numNulls_` 的同步(commit f74cc21160) + +**背景** +- 旧机制:JIT initGroup 直接写 group 的 null bit,但不碰 `Aggregate::numNulls_`; + 而非 JIT extract 的 `isNull()` 依赖 `numNulls_` 短路(为 0 时直接判非 null)。 +- 为弥合差异,曾引入 `HashAggrJitDescriptor/Slot::initSetsNull` 标志 + + `Aggregate::addNumNulls()`,由 `GroupingSet` 在 JIT init 后手工补账。 +- 该机制最初动机:partial agg 中「add 走 JIT、extract 走非 JIT」时的 null diff。 + 现在 add/extract 均支持 JIT,价值大幅下降,且属跨层补丁、封装差。 + +**本次处理** +- 已删除:`Aggregate::addNumNulls()`、`GroupingSet` 中的 `initSetsNull → addNumNulls` + 补账循环、`HashAggrJitDescriptor/Slot::initSetsNull` 字段、各 aggregate 构造处的 + `/*initSetsNull=*/` 实参。 + +**遗留风险(需后续验证 / 补强)** +- 当前 add/extract 仍是 best-effort,存在静默回落非 JIT 的口子,最典型是 **spill**: + - extract 在 `hasSpilled()` / `supportRowBasedOutput_` 时整体跳过 JIT。 + - encoding 不符预期、distinct/mask/sortingKeys 等也会 fallback。 +- 风险场景:某 slot 用了 JIT add(init 只写 null 
bit、未维护 `numNulls_`),但运行时 + 回落非 JIT extract → `isNull()` 因 `numNulls_==0` 短路,把「全 null 组」误判为非 null + → 输出 0 而非 null(静默错数据)。守护用例:`hashAggrJitAllNullGroup`。 + +**后续待办(择机)** +- 重点回归:带 spill 的 partial agg(尤其全 null 组)。 +- 选一条强化方向之一: + - 做法 1(plan-time 硬门槛):只有「add + extract 全程 JIT 有保证」的 slot 才进 JIT + init/add,会 fallback 的(含可能 spill)一开始就不走 JIT。语义最干净。 + - 做法 2(fallback 现算):保留 fallback,但在非 JIT extract 入口扫一遍 null bit 重建 + `numNulls_`,spill 场景也安全,改动小。 + ## Pending ### [P2] chunk 同时 codegen `add_dense` 和 `add_dense_no_null`,编译时间与产物 ×2 From 2363a20870e1d5147a8f5fb76076d1dc885544c0 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 10 Jun 2026 23:13:50 +0800 Subject: [PATCH 33/98] refactor to speed up compile --- bolt/exec/Aggregate.h | 2 +- bolt/exec/CMakeLists.txt | 2 + bolt/exec/RowContainer.cpp | 86 ----- .../prestosql/aggregates/CountAggregate.cpp | 90 +---- .../prestosql/aggregates/MinMaxAggregates.cpp | 119 +------ .../sparksql/aggregates/AverageAggregate.cpp | 311 +----------------- .../sparksql/aggregates/DecimalSumAggregate.h | 120 +------ .../sparksql/aggregates/SumAggregate.cpp | 191 +---------- bolt/jit/CMakeLists.txt | 6 + bolt/jit/aggregation/HashAggrJit.h | 126 +------ bolt/jit/aggregation/HashAggrJitTypes.h | 162 +++++++++ bolt/jit/aggregation/ops/AvgOps.cpp | 137 ++++++++ bolt/jit/aggregation/ops/CountOps.cpp | 102 ++++++ bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 138 ++++++++ bolt/jit/aggregation/ops/DecimalSumOps.cpp | 131 ++++++++ bolt/jit/aggregation/ops/MinMaxOps.cpp | 109 ++++++ bolt/jit/aggregation/ops/SumOps.cpp | 85 +++++ .../runtime/HashAggrDecimalRuntime.cpp | 157 +++++++++ .../aggregation/runtime/HashAggrRuntime.cpp | 103 ++++++ 19 files changed, 1141 insertions(+), 1036 deletions(-) create mode 100644 bolt/jit/aggregation/HashAggrJitTypes.h create mode 100644 bolt/jit/aggregation/ops/AvgOps.cpp create mode 100644 bolt/jit/aggregation/ops/CountOps.cpp create mode 100644 bolt/jit/aggregation/ops/DecimalAvgOps.cpp create mode 100644 
bolt/jit/aggregation/ops/DecimalSumOps.cpp create mode 100644 bolt/jit/aggregation/ops/MinMaxOps.cpp create mode 100644 bolt/jit/aggregation/ops/SumOps.cpp create mode 100644 bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp create mode 100644 bolt/jit/aggregation/runtime/HashAggrRuntime.cpp diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index 301709f4c..bbe2013e9 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -40,7 +40,7 @@ #include "bolt/expression/FunctionSignature.h" #include "bolt/functions/InlineFlatten.h" #ifdef ENABLE_BOLT_JIT -#include "bolt/jit/aggregation/HashAggrJit.h" +#include "bolt/jit/aggregation/HashAggrJitTypes.h" #endif #include "bolt/vector/BaseVector.h" namespace bytedance::bolt::core { diff --git a/bolt/exec/CMakeLists.txt b/bolt/exec/CMakeLists.txt index 7f1a6421d..0b3221240 100644 --- a/bolt/exec/CMakeLists.txt +++ b/bolt/exec/CMakeLists.txt @@ -87,6 +87,8 @@ bolt_add_library( ProbeOperatorState.cpp PushBasedEvent.cpp RowContainer.cpp + ../jit/aggregation/runtime/HashAggrRuntime.cpp + ../jit/aggregation/runtime/HashAggrDecimalRuntime.cpp RowNumber.cpp RowsStreamingWindowBuild.cpp RowsStreamingWindowPartition.cpp diff --git a/bolt/exec/RowContainer.cpp b/bolt/exec/RowContainer.cpp index 12c96caa1..5656e447f 100644 --- a/bolt/exec/RowContainer.cpp +++ b/bolt/exec/RowContainer.cpp @@ -1821,92 +1821,6 @@ __attribute__((__visibility__("default"))) int8_t jit_GetDecodedIsNull( index); } -__attribute__((__visibility__("default"))) void jit_HashAggrResizeVector( - char* vector, - int32_t size) { - reinterpret_cast(vector)->resize(size); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI8( - char* vector, - int32_t row, - int8_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? 
flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI16( - char* vector, - int32_t row, - int16_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI32( - char* vector, - int32_t row, - int32_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI64( - char* vector, - int32_t row, - int64_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatFloat( - char* vector, - int32_t row, - float value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatDouble( - char* vector, - int32_t row, - double value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? 
flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetPartialAvgDouble( - char* vector, - int32_t row, - double sum, - int64_t count, - int8_t isNull) { - auto* rowVector = - reinterpret_cast(vector) - ->as(); - auto* sumVector = rowVector->childAt(0)->asFlatVector(); - auto* countVector = rowVector->childAt(1)->asFlatVector(); - if (isNull) { - rowVector->setNull(row, true); - return; - } - rowVector->setNull(row, false); - sumVector->set(row, sum); - countVector->set(row, count); -} - __attribute__((__visibility__("default"))) int8_t jit_ComplexTypeRowEqVectors( const char* row, int32_t offset, diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index c199be1f4..78fcc657a 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -91,96 +91,8 @@ class CountAggregate : public SimpleNumericAggregate { .scale = 0, .auxPrecision = 0, .auxScale = 0, - .ops = hashAggrJitOps()}; + .ops = jit::getCountOps()}; } - - private: - static void compileHashAggrJitInitGroup( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset, - codegen.builder().getInt64(0)); - } - - static void addInc( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - llvm::Value* inc) { - auto* state = - codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset, - codegen.builder().CreateAdd(state, inc)); - } - - static void compileHashAggrJitAddRawInput( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* /*decoded*/, - llvm::Value* /*row*/, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - addInc(codegen, 
group, slot, codegen.builder().getInt64(1)); - } - - static void compileHashAggrJitAddIntermediateResults( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - llvm::Value* inc = slot.desc.countStar - ? codegen.builder().getInt64(1) - : codegen.castValue( - codegen.loadDecodedValue(decoded, row, slot), - slot.desc.inputKind, - jit::HashAggrJitValueKind::Int64); - addInc(codegen, group, slot, inc); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot&, - bool) { - // count result is always BIGINT and never null. - return true; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - auto* value = codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); - codegen.emitFlatValue( - target.resultVector, - target.row, - jit::HashAggrJitValueKind::Int64, - value, - codegen.builder().getInt8(0)); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "count", - &compileHashAggrJitInitGroup, - &compileHashAggrJitAddRawInput, - &compileHashAggrJitAddIntermediateResults, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } - - public: #endif void toIntermediate( diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index e8f2ced2a..868357cd7 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -82,124 +82,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { .scale = 0, .auxPrecision = 0, .auxScale = 0, - .ops = hashAggrJitOps()}; - } - - private: - static void compileHashAggrJitInitGroup( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - 
const jit::HashAggrJitSlot& slot) { - codegen.setAccumulatorNull(group, slot); - auto* type = codegen.llvmType(slot.desc.accumulatorKind); - if (codegen.isFloatKind(slot.desc.accumulatorKind)) { - codegen.storeValue(group, type, slot.offset, llvm::ConstantFP::get(type, 0.0)); - } else { - codegen.storeValue(group, type, slot.offset, llvm::ConstantInt::get(type, 0)); - } - } - - // min/max use the same logic for raw input and intermediate merge: both pick - // the better (min/max) of the decoded value and the current accumulator. - static void compileHashAggrJitUpdate( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - auto* value = codegen.castValue( - codegen.loadDecodedValue(decoded, row, slot), - slot.desc.inputKind, - slot.desc.accumulatorKind); - auto* type = codegen.llvmType(slot.desc.accumulatorKind); - auto* oldValue = codegen.loadValue(group, type, slot.offset); - auto* nullState = codegen.isAccumulatorNull(group, slot); - llvm::Value* better = nullptr; - if (codegen.isFloatKind(slot.desc.accumulatorKind)) { - auto* oldIsNan = codegen.builder().CreateFCmpUNO(oldValue, oldValue); - auto* valueIsNan = codegen.builder().CreateFCmpUNO(value, value); - if (slot.desc.kind == jit::HashAggrJitKind::Min) { - better = codegen.builder().CreateOr( - codegen.builder().CreateAnd(oldIsNan, codegen.builder().CreateNot(valueIsNan)), - codegen.builder().CreateAnd( - codegen.builder().CreateNot(valueIsNan), - codegen.builder().CreateFCmpOGT(oldValue, value))); - } else { - better = codegen.builder().CreateAnd( - codegen.builder().CreateNot(oldIsNan), - codegen.builder().CreateOr( - valueIsNan, codegen.builder().CreateFCmpOLT(oldValue, value))); - } - } else { - better = slot.desc.kind == jit::HashAggrJitKind::Min - ? 
codegen.builder().CreateICmpSLT(value, oldValue) - : codegen.builder().CreateICmpSGT(value, oldValue); - } - auto* shouldStore = codegen.builder().CreateOr(nullState, better); - codegen.storeValue( - group, - type, - slot.offset, - codegen.builder().CreateSelect(shouldStore, value, oldValue)); - codegen.clearAccumulatorNull(group, slot); - } - - static void compileHashAggrJitAddRawInput( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool checkInputNulls, - llvm::BasicBlock* nextBlock) { - compileHashAggrJitUpdate( - codegen, group, decoded, row, slot, checkInputNulls, nextBlock); - } - - static void compileHashAggrJitAddIntermediateResults( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool checkInputNulls, - llvm::BasicBlock* nextBlock) { - compileHashAggrJitUpdate( - codegen, group, decoded, row, slot, checkInputNulls, nextBlock); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot& slot, - bool) { - // Flat setters exist for i8/i16/i32/i64/f32/f64 only. Int128 (long decimal) - // and Bool have no flat setter yet, fall back to non-JIT extract. 
- return slot.desc.accumulatorKind != jit::HashAggrJitValueKind::Int128 && - slot.desc.accumulatorKind != jit::HashAggrJitValueKind::Bool; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - auto* value = codegen.loadValue(group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); - auto* isNull = codegen.builder().CreateZExt( - codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); - codegen.emitFlatValue( - target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "minmax", - &compileHashAggrJitInitGroup, - &compileHashAggrJitAddRawInput, - &compileHashAggrJitAddIntermediateResults, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; + .ops = jit::getMinMaxOps()}; } protected: diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 5a8f53b81..92ada29c6 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -32,61 +32,6 @@ #include "bolt/functions/lib/aggregates/AverageAggregateBase.h" #include "bolt/functions/sparksql/DecimalUtil.h" -#ifdef ENABLE_BOLT_JIT -#include "bolt/jit/aggregation/HashAggrJit.h" -#include "bolt/type/DecimalUtil.h" - -extern "C" { - -// Partial decimal avg extract: write row(sum:decimal, count:bigint). -// Overflow during sum adjustment -> sum child set to null, count kept. -// (Final decimal avg extract stays on the non-JIT path; the rescale logic is -// too coupled to per-aggregate precision metadata.) 
-__attribute__((__visibility__("default"))) void -jit_HashAggrExtractPartialDecimalAvg( - char* vector, - int32_t row, - char* group, - int32_t offset, - int32_t /*precision*/, - int32_t /*scale*/, - int8_t /*longDecimal*/) { - auto* state = - reinterpret_cast(group + offset); - auto* rowVector = - reinterpret_cast(vector) - ->as(); - auto* sumVector = - rowVector->childAt(0)->asFlatVector(); - auto* countVector = rowVector->childAt(1)->asFlatVector(); - rowVector->setNull(row, false); - countVector->set(row, state->count); - std::optional adjustedSum = - bytedance::bolt::DecimalUtil::adjustSumForOverflow( - state->sum, state->overflow); - if (adjustedSum.has_value()) { - sumVector->set(row, adjustedSum.value()); - } else { - sumVector->setNull(row, true); - } -} - -// Final decimal avg extract is intentionally not implemented in JIT; the -// declaration exists so the JIT module link succeeds, but it is never called -// because canExtract returns false for the final (non-partial) output. 
-__attribute__((__visibility__("default"))) void -jit_HashAggrExtractFinalDecimalAvg( - char* /*vector*/, - int32_t /*row*/, - char* /*group*/, - int32_t /*offset*/, - int32_t /*precision*/, - int32_t /*scale*/, - int8_t /*longDecimal*/) {} - -} // extern "C" -#endif - using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { namespace { @@ -134,7 +79,7 @@ class AverageAggregate .scale = 0, .auxPrecision = 0, .auxScale = 0, - .ops = hashAggrJitOps()}; + .ops = jit::getAvgOps()}; } auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); @@ -152,134 +97,8 @@ class AverageAggregate .scale = 0, .auxPrecision = 0, .auxScale = 0, - .ops = hashAggrJitOps()}; - } - - private: - static void compileHashAggrJitInitGroup( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - codegen.setAccumulatorNull(group, slot); - codegen.storeValue( - group, - codegen.llvmType(slot.desc.accumulatorKind), - slot.offset, - llvm::ConstantFP::get(codegen.llvmType(slot.desc.accumulatorKind), 0.0)); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().getInt64(0)); - } - - static void compileHashAggrJitAddRawInput( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); - auto* value = - codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); - codegen.clearAccumulatorNull(group, slot); - auto* oldSum = - codegen.loadValue(group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); - codegen.storeValue( - group, - codegen.llvmType(slot.desc.accumulatorKind), - slot.offset, - codegen.builder().CreateFAdd(oldSum, value)); - auto* oldCount = codegen.loadValue( - group, codegen.builder().getInt64Ty(), slot.offset + 8); - 
codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); - } - - static void compileHashAggrJitAddIntermediateResults( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - codegen.clearAccumulatorNull(group, slot); - auto* sum = codegen.loadDecodedRowField( - decoded, row, 0, jit::HashAggrJitValueKind::Double); - auto* count = codegen.loadDecodedRowField( - decoded, row, 1, jit::HashAggrJitValueKind::Int64); - auto* oldSum = - codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); - codegen.storeValue( - group, - codegen.builder().getDoubleTy(), - slot.offset, - codegen.builder().CreateFAdd(oldSum, sum)); - auto* oldCount = codegen.loadValue( - group, codegen.builder().getInt64Ty(), slot.offset + 8); - codegen.storeValue( - group, - codegen.builder().getInt64Ty(), - slot.offset + 8, - codegen.builder().CreateAdd(oldCount, count)); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot& slot, - bool) { - // Only double avg (sum=double@offset, count=int64@offset+8) is supported. - return slot.desc.accumulatorKind == jit::HashAggrJitValueKind::Double; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - auto& builder = codegen.builder(); - auto* sum = - codegen.loadValue(group, builder.getDoubleTy(), slot.offset); - auto* count = - codegen.loadValue(group, builder.getInt64Ty(), slot.offset + 8); - if (target.partialOutput) { - // Intermediate output is row(sum:double, count:bigint). All-null group - // yields (0, 0) with a non-null top-level row (isNull = 0), matching the - // non-JIT extractAccumulators path. 
- codegen.emitPartialAvgResult( - target.resultVector, target.row, sum, count, builder.getInt8(0)); - return; - } - // Final output is double avg. count == 0 means all inputs were null -> null. - auto* isNull = builder.CreateZExt( - builder.CreateICmpEQ(count, builder.getInt64(0)), builder.getInt8Ty()); - auto* countAsDouble = builder.CreateSIToFP(count, builder.getDoubleTy()); - auto* avg = builder.CreateFDiv(sum, countAsDouble); - codegen.emitFlatValue( - target.resultVector, - target.row, - jit::HashAggrJitValueKind::Double, - avg, - isNull); + .ops = jit::getAvgOps()}; } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "avg", - &compileHashAggrJitInitGroup, - &compileHashAggrJitAddRawInput, - &compileHashAggrJitAddIntermediateResults, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } - - public: #endif void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) @@ -373,7 +192,7 @@ class DecimalAverageAggregate : public DecimalAggregate { .scale = sumScale, .auxPrecision = resultPrecision, .auxScale = resultScale, - .ops = hashAggrJitOps()}; + .ops = jit::getDecimalAvgOps()}; } #endif @@ -617,130 +436,6 @@ class DecimalAverageAggregate : public DecimalAggregate { } private: -#ifdef ENABLE_BOLT_JIT - static void compileHashAggrJitInitGroup( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateCall( - codegen.module().getFunction("jit_HashAggrInitDecimalAvg"), - {group, codegen.builder().getInt32(slot.offset)}); - } - - static void compileHashAggrJitAddRawInput( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); - codegen.clearAccumulatorNull(group, slot); - const auto 
helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? "jit_HashAggrUpdateDecimalAvgI128" - : "jit_HashAggrUpdateDecimalAvgI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? codegen.castValue( - rawValue, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) - : rawValue}); - } - - static void compileHashAggrJitAddIntermediateResults( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock* nextBlock) { - auto* function = codegen.builder().GetInsertBlock()->getParent(); - auto* continueBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge_cont", - function, - nextBlock); - auto* overflowBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge_overflow", - function, - continueBlock); - auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge", - function, - continueBlock); - auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); - auto* count = codegen.loadDecodedRowField( - decoded, row, 1, jit::HashAggrJitValueKind::Int64); - auto* countPositive = codegen.builder().CreateICmpSGT( - count, codegen.builder().getInt64(0)); - auto* isOverflow = codegen.builder().CreateAnd( - sumIsNull, - codegen.builder().CreateAnd( - codegen.builder().CreateNot(countIsNull), countPositive)); - codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); - - codegen.builder().SetInsertPoint(overflowBlock); - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); - 
codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? "jit_HashAggrMergeDecimalAvgI128" - : "jit_HashAggrMergeDecimalAvgI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? codegen.castValue( - sum, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) - : sum, - count}); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(continueBlock); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot&, - bool partialOutput) { - // Only the partial (extractAccumulators) path is JIT-supported for decimal - // avg. Final avg needs the full per-aggregate rescale logic and stays on - // the non-JIT path. - return partialOutput; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - codegen.emitDecimalAvgExtract( - target.resultVector, target.row, group, slot, target.partialOutput); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "avg_decimal", - &compileHashAggrJitInitGroup, - &compileHashAggrJitAddRawInput, - &compileHashAggrJitAddIntermediateResults, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } -#endif - template inline void mergeSumCount( LongDecimalWithOverflowState* accumulator, diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index e4fc8b2c1..4a2ed825c 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -99,126 +99,8 @@ class DecimalSumAggregate : public exec::Aggregate { .scale = resultScale, .auxPrecision = 0, 
.auxScale = 0, - .ops = hashAggrJitOps()}; + .ops = jit::getDecimalSumOps()}; } - - private: - static void compileHashAggrJitInitGroup( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateCall( - codegen.module().getFunction("jit_HashAggrInitDecimalSum"), - {group, codegen.builder().getInt32(slot.offset)}); - } - - static void compileHashAggrJitAddRawInput( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); - codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? "jit_HashAggrUpdateDecimalSumI128" - : "jit_HashAggrUpdateDecimalSumI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? 
codegen.castValue( - rawValue, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) - : rawValue}); - } - - static void compileHashAggrJitAddIntermediateResults( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock* nextBlock) { - auto* function = codegen.builder().GetInsertBlock()->getParent(); - auto* continueBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge_cont", - function, - nextBlock); - auto* overflowBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge_overflow", - function, - continueBlock); - auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge", - function, - continueBlock); - auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* isEmpty = codegen.loadDecodedRowField( - decoded, row, 1, jit::HashAggrJitValueKind::Int8); - auto* isNotEmpty = - codegen.builder().CreateICmpEQ(isEmpty, codegen.builder().getInt8(0)); - auto* isOverflow = codegen.builder().CreateAnd(sumIsNull, isNotEmpty); - codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); - - codegen.builder().SetInsertPoint(overflowBlock); - codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); - codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? "jit_HashAggrMergeDecimalSumI128" - : "jit_HashAggrMergeDecimalSumI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == jit::HashAggrJitValueKind::Int128 - ? 
codegen.castValue( - sum, slot.desc.inputKind, jit::HashAggrJitValueKind::Int128) - : sum, - isEmpty}); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(continueBlock); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot&, - bool) { - return true; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - codegen.emitDecimalSumExtract( - target.resultVector, target.row, group, slot, target.partialOutput); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "sum_decimal", - &compileHashAggrJitInitGroup, - &compileHashAggrJitAddRawInput, - &compileHashAggrJitAddIntermediateResults, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } - - public: #endif void initializeNewGroups( diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index a0d9b7764..fe76bf2f4 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -33,99 +33,6 @@ #include "bolt/functions/lib/aggregates/SumAggregateBase.h" #include "bolt/functions/sparksql/aggregates/DecimalSumAggregate.h" -#ifdef ENABLE_BOLT_JIT -#include "bolt/type/DecimalUtil.h" - -namespace { -// Mirrors DecimalSumAggregate::computeFinalValue: applies overflow adjustment -// and reports whether the value overflows the result precision range. 
-bytedance::bolt::int128_t jitDecimalSumComputeFinal( - const bytedance::bolt::jit::JitDecimalSumState* state, - int32_t precision, - bool& overflow) { - using bytedance::bolt::DecimalUtil; - bytedance::bolt::int128_t sum = state->sum; - if ((state->overflow == 1 && state->sum < 0) || - (state->overflow == -1 && state->sum > 0)) { - sum = static_cast( - DecimalUtil::kOverflowMultiplier * state->overflow + state->sum); - } else if (state->overflow != 0) { - overflow = true; - return 0; - } - overflow = !DecimalUtil::valueInPrecisionRange(sum, precision); - return sum; -} -} // namespace - -extern "C" { - -// Final decimal sum extract: write FlatVector. Null when the group is -// empty (all inputs null) or the sum overflows the result precision. -__attribute__((__visibility__("default"))) void -jit_HashAggrExtractFinalDecimalSum( - char* vector, - int32_t row, - char* group, - int32_t offset, - int32_t precision, - int32_t /*scale*/, - int8_t /*longDecimal*/) { - auto* state = - reinterpret_cast(group + offset); - auto* flat = reinterpret_cast(vector) - ->as>(); - if (state->isEmpty) { - flat->setNull(row, true); - return; - } - bool overflow = false; - auto result = jitDecimalSumComputeFinal(state, precision, overflow); - if (overflow) { - flat->setNull(row, true); - } else { - flat->set(row, result); - } -} - -// Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). 
-__attribute__((__visibility__("default"))) void -jit_HashAggrExtractPartialDecimalSum( - char* vector, - int32_t row, - char* group, - int32_t offset, - int32_t precision, - int32_t /*scale*/, - int8_t /*longDecimal*/) { - auto* state = - reinterpret_cast(group + offset); - auto* rowVector = - reinterpret_cast(vector) - ->as(); - auto* sumVector = rowVector->childAt(0) - ->asFlatVector(); - auto* isEmptyVector = rowVector->childAt(1)->asFlatVector(); - rowVector->setNull(row, false); - if (state->isEmpty) { - sumVector->set(row, 0); - isEmptyVector->set(row, true); - return; - } - bool overflow = false; - auto result = jitDecimalSumComputeFinal(state, precision, overflow); - if (overflow) { - sumVector->setNull(row, true); - isEmptyVector->set(row, false); - } else { - sumVector->set(row, result); - isEmptyVector->set(row, state->isEmpty); - } -} - -} // extern "C" -#endif - using namespace bytedance::bolt::functions::aggregate; namespace bytedance::bolt::functions::aggregate::sparksql { @@ -177,104 +84,8 @@ class SumAggregate : public SumAggregateBase { .scale = 0, .auxPrecision = 0, .auxScale = 0, - .ops = hashAggrJitOps()}; + .ops = jit::getSumOps()}; } - - private: - static void compileHashAggrJitInitGroup( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot) { - codegen.setAccumulatorNull(group, slot); - auto* accType = codegen.llvmType(slot.desc.accumulatorKind); - if (codegen.isFloatKind(slot.desc.accumulatorKind)) { - codegen.storeValue( - group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); - } else { - codegen.storeValue( - group, accType, slot.offset, llvm::ConstantInt::get(accType, 0)); - } - } - - // sum uses the same logic for raw input and intermediate merge: add the - // decoded value into the running accumulator. 
- static void compileHashAggrJitAccumulate( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool, - llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); - auto* value = - codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); - auto* accType = codegen.llvmType(slot.desc.accumulatorKind); - codegen.clearAccumulatorNull(group, slot); - auto* oldValue = codegen.loadValue(group, accType, slot.offset); - auto* newValue = codegen.isFloatKind(slot.desc.accumulatorKind) - ? codegen.builder().CreateFAdd(oldValue, value) - : codegen.builder().CreateAdd(oldValue, value); - codegen.storeValue(group, accType, slot.offset, newValue); - } - - static void compileHashAggrJitAddRawInput( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool checkInputNulls, - llvm::BasicBlock* nextBlock) { - compileHashAggrJitAccumulate( - codegen, group, decoded, row, slot, checkInputNulls, nextBlock); - } - - static void compileHashAggrJitAddIntermediateResults( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* decoded, - llvm::Value* row, - const jit::HashAggrJitSlot& slot, - bool checkInputNulls, - llvm::BasicBlock* nextBlock) { - compileHashAggrJitAccumulate( - codegen, group, decoded, row, slot, checkInputNulls, nextBlock); - } - - static bool canCompileHashAggrJitExtract( - const jit::HashAggrJitSlot& slot, - bool) { - // spark sum intermediate type == result type (bigint=bigint / double=double). 
- return slot.desc.accumulatorKind == jit::HashAggrJitValueKind::Int64 || - slot.desc.accumulatorKind == jit::HashAggrJitValueKind::Double; - } - - static void compileHashAggrJitExtract( - jit::HashAggrJitCodegen& codegen, - llvm::Value* group, - const jit::HashAggrJitSlot& slot, - const jit::HashAggrJitExtractTarget& target) { - auto* value = - codegen.loadValue(group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); - auto* isNull = codegen.builder().CreateZExt( - codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); - codegen.emitFlatValue( - target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); - } - - static const jit::HashAggrJitOps* hashAggrJitOps() { - static const jit::HashAggrJitOps kOps{ - "sum", - &compileHashAggrJitInitGroup, - &compileHashAggrJitAddRawInput, - &compileHashAggrJitAddIntermediateResults, - &canCompileHashAggrJitExtract, - &compileHashAggrJitExtract}; - return &kOps; - } - - public: #endif }; diff --git a/bolt/jit/CMakeLists.txt b/bolt/jit/CMakeLists.txt index a0af680ce..4962548ed 100644 --- a/bolt/jit/CMakeLists.txt +++ b/bolt/jit/CMakeLists.txt @@ -17,6 +17,12 @@ bolt_add_library( bolt_thrustjit CompiledModule.cpp ThrustJITv2.cpp aggregation/HashAggrJit.cpp + aggregation/ops/CountOps.cpp + aggregation/ops/MinMaxOps.cpp + aggregation/ops/SumOps.cpp + aggregation/ops/AvgOps.cpp + aggregation/ops/DecimalSumOps.cpp + aggregation/ops/DecimalAvgOps.cpp RowContainer/RowContainerCodeGenerator.cpp RowContainer/RowEqVectorsCodeGenerator.cpp ) diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index e757ceccb..642e16983 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -13,124 +13,14 @@ #include #include "bolt/jit/CompiledModule.h" +#include "bolt/jit/aggregation/HashAggrJitTypes.h" #include "bolt/type/Type.h" namespace bytedance::bolt::jit { class HashAggrJitCodegen; - -// JIT-internal accumulator layouts for decimal 
sum/avg. Shared between the JIT -// codegen runtime helpers and the extract runtime helpers (which live in a -// different translation unit and need DecimalUtil). -struct JitDecimalSumState { - bytedance::bolt::int128_t sum{0}; - int64_t overflow{0}; - bool isEmpty{true}; -}; - -struct JitDecimalAvgState { - bytedance::bolt::int128_t sum{0}; - int64_t count{0}; - int64_t overflow{0}; -}; - -struct HashAggrJitSlot; struct HashAggrJitExtractTarget; -// Runtime input descriptor consumed by JIT add_dense functions. -// GroupingSet prepares one descriptor per aggregate input for each batch by -// decoding the original vector into a flat/constant base plus a single indices -// mapping. This keeps generated IR independent of the batch's original vector -// encoding (flat/dictionary/constant) while allowing the hot loop to load -// values directly instead of calling jit_GetDecodedValue* helpers per row. -struct HashAggrJitDecodedInput { - const void* values{nullptr}; - // Always points to a top-level-row -> base-row mapping. For flat inputs this - // is a consecutive mapping; for constant inputs it maps every row to the - // constant value index. - const int32_t* indices{nullptr}; - // Top-level nulls. If non-null, bit 'row' indicates whether the input row is - // null. This is intentionally row-based rather than base-index-based to keep - // generated IR independent of dictionary/null wrapping details. - const uint64_t* nulls{nullptr}; - // Original DecodedVector pointer. Kept as fallback for row-field helpers. - const void* decodedVector{nullptr}; - // Raw ROW child fields for intermediate avg merge inputs. The top-level - // ROW may still be dictionary/constant wrapped; 'indices' maps rows to the - // flat child row. Only the first two fields are needed by avg: sum, count. 
- const void* rowField0Values{nullptr}; - const uint64_t* rowField0Nulls{nullptr}; - const void* rowField1Values{nullptr}; - const uint64_t* rowField1Nulls{nullptr}; -}; - -// Runtime output descriptor consumed by JIT extract functions. GroupingSet -// prepares one descriptor per aggregate output after resizing the result vector. -// Primitive flat outputs write values/null bits directly from generated IR; -// complex outputs keep using vector helper fallbacks via 'vector'. -struct HashAggrJitOutput { - void* values{nullptr}; - uint64_t* nulls{nullptr}; - void* vector{nullptr}; - // Raw ROW child fields for partial avg output: field 0 = sum(double), - // field 1 = count(int64). Other outputs leave these null and use 'values' - // or helper fallback via 'vector'. - void* rowField0Values{nullptr}; - uint64_t* rowField0Nulls{nullptr}; - void* rowField1Values{nullptr}; - uint64_t* rowField1Nulls{nullptr}; -}; - -struct HashAggrJitPlanContext { - bool isRawInput{false}; - bool isPartialOutput{false}; - int32_t inputCount{0}; - TypePtr inputType; - - bool isCountStar() const { - return isRawInput && inputCount == 0; - } -}; - -enum class HashAggrJitKind : uint8_t { - Count, - Sum, - Min, - Max, - Avg, -}; - -enum class HashAggrJitValueKind : uint8_t { - Bool, - Int8, - Int16, - Int32, - Int64, - Int128, - Float, - Double, -}; - -struct HashAggrJitDescriptor { - HashAggrJitKind kind; - HashAggrJitValueKind inputKind; - HashAggrJitValueKind accumulatorKind; - bool countStar{false}; - bool mergeInput{false}; - bool decimal{false}; - // Result decimal precision/scale, used by decimal extract overflow checks. - // Only meaningful when decimal == true. - int32_t precision{0}; - int32_t scale{0}; - // Secondary decimal precision/scale. For decimal avg extract, precision/scale - // carry the intermediate sum type and aux* carry the result type. 
- int32_t auxPrecision{0}; - int32_t auxScale{0}; - const struct HashAggrJitOps* ops{nullptr}; - - std::string signature() const; -}; - struct HashAggrJitOps { using CreateFn = void (*)(HashAggrJitCodegen&, llvm::Value* group, const HashAggrJitSlot&); @@ -157,16 +47,6 @@ struct HashAggrJitOps { ExtractFn extract; }; -struct HashAggrJitSlot { - int32_t aggregateIndex; - int32_t offset; - int32_t nullByte; - uint8_t nullMask; - // All aggregate-level traits live in the descriptor; IR-side code reads them - // through 'desc'. Only the row-layout fields above are slot-specific. - HashAggrJitDescriptor desc; -}; - struct HashAggrJitExtractTarget { llvm::Value* resultVector; llvm::Value* row; @@ -320,10 +200,6 @@ class HashAggrJitChunk { bool disabled_{false}; }; -bool isHashAggrJitSupportedType(TypeKind kind); -std::optional hashAggrJitValueKind(TypeKind kind); -std::string hashAggrJitValueKindName(HashAggrJitValueKind kind); - } // namespace bytedance::bolt::jit #endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h new file mode 100644 index 000000000..f8620f62c --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -0,0 +1,162 @@ +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include +#include +#include + +#include "bolt/type/Type.h" + +// Lightweight, LLVM-free metadata shared between aggregate functions (which +// only produce a HashAggrJitDescriptor) and the JIT codegen layer. Keeping this +// header free of lets Aggregate.h and other non-JIT translation +// units depend on the JIT planning interface without pulling in the heavy LLVM +// IR headers. The codegen-only declarations (HashAggrJitOps function-pointer +// table, HashAggrJitCodegen, HashAggrJitChunk) live in HashAggrJit.h. + +namespace bytedance::bolt::jit { + +// JIT-internal accumulator layouts for decimal sum/avg. 
Shared between the JIT +// codegen runtime helpers and the extract runtime helpers (which live in a +// different translation unit and need DecimalUtil). +struct JitDecimalSumState { + bytedance::bolt::int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; +}; + +struct JitDecimalAvgState { + bytedance::bolt::int128_t sum{0}; + int64_t count{0}; + int64_t overflow{0}; +}; + +// Runtime input descriptor consumed by JIT add_dense functions. +// GroupingSet prepares one descriptor per aggregate input for each batch by +// decoding the original vector into a flat/constant base plus a single indices +// mapping. This keeps generated IR independent of the batch's original vector +// encoding (flat/dictionary/constant) while allowing the hot loop to load +// values directly instead of calling jit_GetDecodedValue* helpers per row. +struct HashAggrJitDecodedInput { + const void* values{nullptr}; + // Always points to a top-level-row -> base-row mapping. For flat inputs this + // is a consecutive mapping; for constant inputs it maps every row to the + // constant value index. + const int32_t* indices{nullptr}; + // Top-level nulls. If non-null, bit 'row' indicates whether the input row is + // null. This is intentionally row-based rather than base-index-based to keep + // generated IR independent of dictionary/null wrapping details. + const uint64_t* nulls{nullptr}; + // Original DecodedVector pointer. Kept as fallback for row-field helpers. + const void* decodedVector{nullptr}; + // Raw ROW child fields for intermediate avg merge inputs. The top-level + // ROW may still be dictionary/constant wrapped; 'indices' maps rows to the + // flat child row. Only the first two fields are needed by avg: sum, count. + const void* rowField0Values{nullptr}; + const uint64_t* rowField0Nulls{nullptr}; + const void* rowField1Values{nullptr}; + const uint64_t* rowField1Nulls{nullptr}; +}; + +// Runtime output descriptor consumed by JIT extract functions. 
GroupingSet +// prepares one descriptor per aggregate output after resizing the result vector. +// Primitive flat outputs write values/null bits directly from generated IR; +// complex outputs keep using vector helper fallbacks via 'vector'. +struct HashAggrJitOutput { + void* values{nullptr}; + uint64_t* nulls{nullptr}; + void* vector{nullptr}; + // Raw ROW child fields for partial avg output: field 0 = sum(double), + // field 1 = count(int64). Other outputs leave these null and use 'values' + // or helper fallback via 'vector'. + void* rowField0Values{nullptr}; + uint64_t* rowField0Nulls{nullptr}; + void* rowField1Values{nullptr}; + uint64_t* rowField1Nulls{nullptr}; +}; + +struct HashAggrJitPlanContext { + bool isRawInput{false}; + bool isPartialOutput{false}; + int32_t inputCount{0}; + TypePtr inputType; + + bool isCountStar() const { + return isRawInput && inputCount == 0; + } +}; + +enum class HashAggrJitKind : uint8_t { + Count, + Sum, + Min, + Max, + Avg, +}; + +enum class HashAggrJitValueKind : uint8_t { + Bool, + Int8, + Int16, + Int32, + Int64, + Int128, + Float, + Double, +}; + +// Forward declaration: the codegen function-pointer table is defined in +// HashAggrJit.h (it references llvm:: types). Descriptors only hold a pointer +// to it, so a forward declaration is enough here and keeps this header +// LLVM-free. +struct HashAggrJitOps; + +struct HashAggrJitDescriptor { + HashAggrJitKind kind; + HashAggrJitValueKind inputKind; + HashAggrJitValueKind accumulatorKind; + bool countStar{false}; + bool mergeInput{false}; + bool decimal{false}; + // Result decimal precision/scale, used by decimal extract overflow checks. + // Only meaningful when decimal == true. + int32_t precision{0}; + int32_t scale{0}; + // Secondary decimal precision/scale. For decimal avg extract, precision/scale + // carry the intermediate sum type and aux* carry the result type. 
+ int32_t auxPrecision{0}; + int32_t auxScale{0}; + const HashAggrJitOps* ops{nullptr}; + + std::string signature() const; +}; + +struct HashAggrJitSlot { + int32_t aggregateIndex; + int32_t offset; + int32_t nullByte; + uint8_t nullMask; + // All aggregate-level traits live in the descriptor; IR-side code reads them + // through 'desc'. Only the row-layout fields above are slot-specific. + HashAggrJitDescriptor desc; +}; + +bool isHashAggrJitSupportedType(TypeKind kind); +std::optional hashAggrJitValueKind(TypeKind kind); +std::string hashAggrJitValueKindName(HashAggrJitValueKind kind); + +// Per-aggregate codegen function tables. The definitions (which reference +// llvm:: types) live in bolt/jit/aggregation/ops/*Ops.cpp. Aggregate functions +// only need the returned pointer to populate HashAggrJitDescriptor::ops, so +// these declarations stay LLVM-free here. +const HashAggrJitOps* getCountOps(); +const HashAggrJitOps* getMinMaxOps(); +const HashAggrJitOps* getSumOps(); +const HashAggrJitOps* getAvgOps(); +const HashAggrJitOps* getDecimalSumOps(); +const HashAggrJitOps* getDecimalAvgOps(); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp new file mode 100644 index 000000000..33fca6860 --- /dev/null +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileAvgInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.storeValue( + group, + codegen.llvmType(slot.desc.accumulatorKind), + slot.offset, + llvm::ConstantFP::get(codegen.llvmType(slot.desc.accumulatorKind), 0.0)); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().getInt64(0)); +} + +void compileAvgAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* value = + codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldSum = codegen.loadValue( + group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); + codegen.storeValue( + group, + codegen.llvmType(slot.desc.accumulatorKind), + slot.offset, + codegen.builder().CreateFAdd(oldSum, value)); + auto* oldCount = codegen.loadValue( + group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, codegen.builder().getInt64(1))); +} + +void compileAvgAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + codegen.clearAccumulatorNull(group, slot); + auto* sum = + codegen.loadDecodedRowField(decoded, row, 0, HashAggrJitValueKind::Double); + auto* count = + codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int64); + auto* oldSum = + codegen.loadValue(group, 
codegen.builder().getDoubleTy(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getDoubleTy(), + slot.offset, + codegen.builder().CreateFAdd(oldSum, sum)); + auto* oldCount = codegen.loadValue( + group, codegen.builder().getInt64Ty(), slot.offset + 8); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset + 8, + codegen.builder().CreateAdd(oldCount, count)); +} + +bool canCompileAvgExtract(const HashAggrJitSlot& slot, bool) { + // Only double avg (sum=double@offset, count=int64@offset+8) is supported. + return slot.desc.accumulatorKind == HashAggrJitValueKind::Double; +} + +void compileAvgExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto& builder = codegen.builder(); + auto* sum = codegen.loadValue(group, builder.getDoubleTy(), slot.offset); + auto* count = codegen.loadValue(group, builder.getInt64Ty(), slot.offset + 8); + if (target.partialOutput) { + // Intermediate output is row(sum:double, count:bigint). All-null group + // yields (0, 0) with a non-null top-level row (isNull = 0), matching the + // non-JIT extractAccumulators path. + codegen.emitPartialAvgResult( + target.resultVector, target.row, sum, count, builder.getInt8(0)); + return; + } + // Final output is double avg. count == 0 means all inputs were null -> null. 
+ auto* isNull = builder.CreateZExt( + builder.CreateICmpEQ(count, builder.getInt64(0)), builder.getInt8Ty()); + auto* countAsDouble = builder.CreateSIToFP(count, builder.getDoubleTy()); + auto* avg = builder.CreateFDiv(sum, countAsDouble); + codegen.emitFlatValue( + target.resultVector, + target.row, + HashAggrJitValueKind::Double, + avg, + isNull); +} + +} // namespace + +const HashAggrJitOps* getAvgOps() { + static const HashAggrJitOps kOps{ + "avg", + &compileAvgInitGroup, + &compileAvgAddRawInput, + &compileAvgAddIntermediateResults, + &canCompileAvgExtract, + &compileAvgExtract}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp new file mode 100644 index 000000000..3e586e25d --- /dev/null +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileCountInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset, + codegen.builder().getInt64(0)); +} + +void addInc( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + llvm::Value* inc) { + auto* state = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); + codegen.storeValue( + group, + codegen.builder().getInt64Ty(), + slot.offset, + codegen.builder().CreateAdd(state, inc)); +} + +void compileCountAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* /*decoded*/, + llvm::Value* /*row*/, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + addInc(codegen, group, slot, codegen.builder().getInt64(1)); +} + +void 
compileCountAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + llvm::Value* inc = slot.desc.countStar + ? codegen.builder().getInt64(1) + : codegen.castValue( + codegen.loadDecodedValue(decoded, row, slot), + slot.desc.inputKind, + HashAggrJitValueKind::Int64); + addInc(codegen, group, slot, inc); +} + +bool canCompileCountExtract(const HashAggrJitSlot&, bool) { + // count result is always BIGINT and never null. + return true; +} + +void compileCountExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto* value = + codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); + codegen.emitFlatValue( + target.resultVector, + target.row, + HashAggrJitValueKind::Int64, + value, + codegen.builder().getInt8(0)); +} + +} // namespace + +const HashAggrJitOps* getCountOps() { + static const HashAggrJitOps kOps{ + "count", + &compileCountInitGroup, + &compileCountAddRawInput, + &compileCountAddIntermediateResults, + &canCompileCountExtract, + &compileCountExtract}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp new file mode 100644 index 000000000..29c503f56 --- /dev/null +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileDecimalAvgInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateCall( + codegen.module().getFunction("jit_HashAggrInitDecimalAvg"), + {group, codegen.builder().getInt32(slot.offset)}); +} + +void compileDecimalAvgAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalAvgI128" + : "jit_HashAggrUpdateDecimalAvgI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? 
codegen.castValue( + rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128) + : rawValue}); +} + +void compileDecimalAvgAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock* nextBlock) { + auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "avg_decimal_merge", + function, + continueBlock); + auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); + auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); + auto* count = + codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int64); + auto* countPositive = + codegen.builder().CreateICmpSGT(count, codegen.builder().getInt64(0)); + auto* isOverflow = codegen.builder().CreateAnd( + sumIsNull, + codegen.builder().CreateAnd( + codegen.builder().CreateNot(countIsNull), countPositive)); + codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + codegen.builder().SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(mergeBlock); + auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? 
"jit_HashAggrMergeDecimalAvgI128" + : "jit_HashAggrMergeDecimalAvgI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? codegen.castValue( + sum, slot.desc.inputKind, HashAggrJitValueKind::Int128) + : sum, + count}); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(continueBlock); +} + +bool canCompileDecimalAvgExtract(const HashAggrJitSlot&, bool partialOutput) { + // Only the partial (extractAccumulators) path is JIT-supported for decimal + // avg. Final avg needs the full per-aggregate rescale logic and stays on + // the non-JIT path. + return partialOutput; +} + +void compileDecimalAvgExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + codegen.emitDecimalAvgExtract( + target.resultVector, target.row, group, slot, target.partialOutput); +} + +} // namespace + +const HashAggrJitOps* getDecimalAvgOps() { + static const HashAggrJitOps kOps{ + "avg_decimal", + &compileDecimalAvgInitGroup, + &compileDecimalAvgAddRawInput, + &compileDecimalAvgAddIntermediateResults, + &canCompileDecimalAvgExtract, + &compileDecimalAvgExtract}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp new file mode 100644 index 000000000..38c226589 --- /dev/null +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileDecimalSumInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateCall( + codegen.module().getFunction("jit_HashAggrInitDecimalSum"), + {group, codegen.builder().getInt32(slot.offset)}); +} + +void compileDecimalSumAddRawInput( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? "jit_HashAggrUpdateDecimalSumI128" + : "jit_HashAggrUpdateDecimalSumI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? 
codegen.castValue( + rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128) + : rawValue}); +} + +void compileDecimalSumAddIntermediateResults( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock* nextBlock) { + auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto* continueBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge_cont", + function, + nextBlock); + auto* overflowBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge_overflow", + function, + continueBlock); + auto* mergeBlock = llvm::BasicBlock::Create( + codegen.module().getContext(), + "sum_decimal_merge", + function, + continueBlock); + auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); + auto* isEmpty = + codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int8); + auto* isNotEmpty = + codegen.builder().CreateICmpEQ(isEmpty, codegen.builder().getInt8(0)); + auto* isOverflow = codegen.builder().CreateAnd(sumIsNull, isNotEmpty); + codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + codegen.builder().SetInsertPoint(overflowBlock); + codegen.setAccumulatorNull(group, slot); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(mergeBlock); + auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); + codegen.clearAccumulatorNull(group, slot); + const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? "jit_HashAggrMergeDecimalSumI128" + : "jit_HashAggrMergeDecimalSumI64"; + codegen.builder().CreateCall( + codegen.module().getFunction(helper), + {group, + codegen.builder().getInt32(slot.offset), + slot.desc.inputKind == HashAggrJitValueKind::Int128 + ? 
codegen.castValue( + sum, slot.desc.inputKind, HashAggrJitValueKind::Int128) + : sum, + isEmpty}); + codegen.builder().CreateBr(continueBlock); + + codegen.builder().SetInsertPoint(continueBlock); +} + +bool canCompileDecimalSumExtract(const HashAggrJitSlot&, bool) { + return true; +} + +void compileDecimalSumExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + codegen.emitDecimalSumExtract( + target.resultVector, target.row, group, slot, target.partialOutput); +} + +} // namespace + +const HashAggrJitOps* getDecimalSumOps() { + static const HashAggrJitOps kOps{ + "sum_decimal", + &compileDecimalSumInitGroup, + &compileDecimalSumAddRawInput, + &compileDecimalSumAddIntermediateResults, + &canCompileDecimalSumExtract, + &compileDecimalSumExtract}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp new file mode 100644 index 000000000..2a39d9c0f --- /dev/null +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileMinMaxInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + auto* type = codegen.llvmType(slot.desc.accumulatorKind); + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { + codegen.storeValue(group, type, slot.offset, llvm::ConstantFP::get(type, 0.0)); + } else { + codegen.storeValue(group, type, slot.offset, llvm::ConstantInt::get(type, 0)); + } +} + +// min/max use the same logic for raw input and intermediate merge: both pick +// the better (min/max) of the decoded value and the current accumulator. 
+void compileMinMaxUpdate( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* value = codegen.castValue( + codegen.loadDecodedValue(decoded, row, slot), + slot.desc.inputKind, + slot.desc.accumulatorKind); + auto* type = codegen.llvmType(slot.desc.accumulatorKind); + auto* oldValue = codegen.loadValue(group, type, slot.offset); + auto* nullState = codegen.isAccumulatorNull(group, slot); + llvm::Value* better = nullptr; + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { + auto* oldIsNan = codegen.builder().CreateFCmpUNO(oldValue, oldValue); + auto* valueIsNan = codegen.builder().CreateFCmpUNO(value, value); + if (slot.desc.kind == HashAggrJitKind::Min) { + better = codegen.builder().CreateOr( + codegen.builder().CreateAnd(oldIsNan, codegen.builder().CreateNot(valueIsNan)), + codegen.builder().CreateAnd( + codegen.builder().CreateNot(valueIsNan), + codegen.builder().CreateFCmpOGT(oldValue, value))); + } else { + better = codegen.builder().CreateAnd( + codegen.builder().CreateNot(oldIsNan), + codegen.builder().CreateOr( + valueIsNan, codegen.builder().CreateFCmpOLT(oldValue, value))); + } + } else { + better = slot.desc.kind == HashAggrJitKind::Min + ? codegen.builder().CreateICmpSLT(value, oldValue) + : codegen.builder().CreateICmpSGT(value, oldValue); + } + auto* shouldStore = codegen.builder().CreateOr(nullState, better); + codegen.storeValue( + group, + type, + slot.offset, + codegen.builder().CreateSelect(shouldStore, value, oldValue)); + codegen.clearAccumulatorNull(group, slot); +} + +bool canCompileMinMaxExtract(const HashAggrJitSlot& slot, bool) { + // Flat setters exist for i8/i16/i32/i64/f32/f64 only. Int128 (long decimal) + // and Bool have no flat setter yet, fall back to non-JIT extract. 
+ return slot.desc.accumulatorKind != HashAggrJitValueKind::Int128 && + slot.desc.accumulatorKind != HashAggrJitValueKind::Bool; +} + +void compileMinMaxExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue( + group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); + auto* isNull = codegen.builder().CreateZExt( + codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); + codegen.emitFlatValue( + target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); +} + +} // namespace + +const HashAggrJitOps* getMinMaxOps() { + static const HashAggrJitOps kOps{ + "minmax", + &compileMinMaxInitGroup, + &compileMinMaxUpdate, + &compileMinMaxUpdate, + &canCompileMinMaxExtract, + &compileMinMaxExtract}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp new file mode 100644 index 000000000..6324034db --- /dev/null +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +namespace bytedance::bolt::jit { + +namespace { + +void compileSumInitGroup( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) { + codegen.setAccumulatorNull(group, slot); + auto* accType = codegen.llvmType(slot.desc.accumulatorKind); + if (codegen.isFloatKind(slot.desc.accumulatorKind)) { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantFP::get(accType, 0.0)); + } else { + codegen.storeValue( + group, accType, slot.offset, llvm::ConstantInt::get(accType, 0)); + } +} + +// sum uses the same logic for raw input and intermediate merge: add the +// decoded value into the running accumulator. 
+void compileSumAccumulate( + HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* decoded, + llvm::Value* row, + const HashAggrJitSlot& slot, + bool, + llvm::BasicBlock*) { + auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* value = codegen.castValue( + rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); + auto* accType = codegen.llvmType(slot.desc.accumulatorKind); + codegen.clearAccumulatorNull(group, slot); + auto* oldValue = codegen.loadValue(group, accType, slot.offset); + auto* newValue = codegen.isFloatKind(slot.desc.accumulatorKind) + ? codegen.builder().CreateFAdd(oldValue, value) + : codegen.builder().CreateAdd(oldValue, value); + codegen.storeValue(group, accType, slot.offset, newValue); +} + +bool canCompileSumExtract(const HashAggrJitSlot& slot, bool) { + // spark sum intermediate type == result type (bigint=bigint / double=double). + return slot.desc.accumulatorKind == HashAggrJitValueKind::Int64 || + slot.desc.accumulatorKind == HashAggrJitValueKind::Double; +} + +void compileSumExtract( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto* value = codegen.loadValue( + group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); + auto* isNull = codegen.builder().CreateZExt( + codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); + codegen.emitFlatValue( + target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); +} + +} // namespace + +const HashAggrJitOps* getSumOps() { + static const HashAggrJitOps kOps{ + "sum", + &compileSumInitGroup, + &compileSumAccumulate, + &compileSumAccumulate, + &canCompileSumExtract, + &compileSumExtract}; + return &kOps; +} + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp new file mode 100644 index 
000000000..9991ab16a --- /dev/null +++ b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Decimal sum/avg extract runtime helpers called from JIT-generated HashAggr +// extract IR. They read the JIT decimal accumulator (JitDecimalSumState / +// JitDecimalAvgState) from 'group + offset', apply overflow / precision +// adjustment and write the result into the output vector. Resolved by the ORC +// JIT through the process global symbol table, so they only need default +// visibility and to be linked into the host process. These mirror the +// per-aggregate computeFinalValue logic but depend only on the type-layer +// DecimalUtil, so they live next to the other HashAggr runtime helpers. + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJitTypes.h" +#include "bolt/type/DecimalUtil.h" +#include "bolt/vector/ComplexVector.h" +#include "bolt/vector/FlatVector.h" + +namespace { + +// Mirrors DecimalSumAggregate::computeFinalValue: applies overflow adjustment +// and reports whether the value overflows the result precision range. +bytedance::bolt::int128_t jitDecimalSumComputeFinal( + const bytedance::bolt::jit::JitDecimalSumState* state, + int32_t precision, + bool& overflow) { + using bytedance::bolt::DecimalUtil; + bytedance::bolt::int128_t sum = state->sum; + if ((state->overflow == 1 && state->sum < 0) || + (state->overflow == -1 && state->sum > 0)) { + sum = static_cast( + DecimalUtil::kOverflowMultiplier * state->overflow + state->sum); + } else if (state->overflow != 0) { + overflow = true; + return 0; + } + overflow = !DecimalUtil::valueInPrecisionRange(sum, precision); + return sum; +} + +} // namespace + +extern "C" { + +// Final decimal sum extract: write FlatVector. Null when the group is +// empty (all inputs null) or the sum overflows the result precision. 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int8_t /*longDecimal*/) { + auto* state = + reinterpret_cast(group + offset); + auto* flat = reinterpret_cast(vector) + ->as>(); + if (state->isEmpty) { + flat->setNull(row, true); + return; + } + bool overflow = false; + auto result = jitDecimalSumComputeFinal(state, precision, overflow); + if (overflow) { + flat->setNull(row, true); + } else { + flat->set(row, result); + } +} + +// Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/, + int8_t /*longDecimal*/) { + auto* state = + reinterpret_cast(group + offset); + auto* rowVector = reinterpret_cast(vector) + ->as(); + auto* sumVector = + rowVector->childAt(0)->asFlatVector(); + auto* isEmptyVector = rowVector->childAt(1)->asFlatVector(); + rowVector->setNull(row, false); + if (state->isEmpty) { + sumVector->set(row, 0); + isEmptyVector->set(row, true); + return; + } + bool overflow = false; + auto result = jitDecimalSumComputeFinal(state, precision, overflow); + if (overflow) { + sumVector->setNull(row, true); + isEmptyVector->set(row, false); + } else { + sumVector->set(row, result); + isEmptyVector->set(row, state->isEmpty); + } +} + +// Partial decimal avg extract: write row(sum:decimal, count:bigint). +// Overflow during sum adjustment -> sum child set to null, count kept. +// (Final decimal avg extract stays on the non-JIT path; the rescale logic is +// too coupled to per-aggregate precision metadata.) 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t /*precision*/, + int32_t /*scale*/, + int8_t /*longDecimal*/) { + auto* state = + reinterpret_cast(group + offset); + auto* rowVector = reinterpret_cast(vector) + ->as(); + auto* sumVector = + rowVector->childAt(0)->asFlatVector(); + auto* countVector = rowVector->childAt(1)->asFlatVector(); + rowVector->setNull(row, false); + countVector->set(row, state->count); + std::optional adjustedSum = + bytedance::bolt::DecimalUtil::adjustSumForOverflow( + state->sum, state->overflow); + if (adjustedSum.has_value()) { + sumVector->set(row, adjustedSum.value()); + } else { + sumVector->setNull(row, true); + } +} + +// Final decimal avg extract is intentionally not implemented in JIT; the +// declaration exists so the JIT module link succeeds, but it is never called +// because canExtract returns false for the final (non-partial) output. +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalDecimalAvg( + char* /*vector*/, + int32_t /*row*/, + char* /*group*/, + int32_t /*offset*/, + int32_t /*precision*/, + int32_t /*scale*/, + int8_t /*longDecimal*/) {} + +} // extern "C" + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp new file mode 100644 index 000000000..457a9f1ff --- /dev/null +++ b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Runtime helpers called from JIT-generated HashAggr extract IR. These are +// plain extern "C" functions resolved by the ORC JIT through the process +// global symbol table (see ThrustJITv2::Create / LoadLibraryPermanently), so +// they only need default visibility and to be linked into the host process. 
+// They were previously colocated in RowContainer.cpp purely because the +// jit_GetDecodedValue* helpers already lived there. + +#include "bolt/vector/ComplexVector.h" +#include "bolt/vector/FlatVector.h" + +extern "C" { + +__attribute__((__visibility__("default"))) void jit_HashAggrResizeVector( + char* vector, + int32_t size) { + reinterpret_cast(vector)->resize(size); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI8( + char* vector, + int32_t row, + int8_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI16( + char* vector, + int32_t row, + int16_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI32( + char* vector, + int32_t row, + int32_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI64( + char* vector, + int32_t row, + int64_t value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatFloat( + char* vector, + int32_t row, + float value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatDouble( + char* vector, + int32_t row, + double value, + int8_t isNull) { + auto* flat = reinterpret_cast(vector) + ->as>(); + isNull ? 
flat->setNull(row, true) : flat->set(row, value); +} + +__attribute__((__visibility__("default"))) void jit_HashAggrSetPartialAvgDouble( + char* vector, + int32_t row, + double sum, + int64_t count, + int8_t isNull) { + auto* rowVector = reinterpret_cast(vector) + ->as(); + auto* sumVector = rowVector->childAt(0)->asFlatVector(); + auto* countVector = rowVector->childAt(1)->asFlatVector(); + if (isNull) { + rowVector->setNull(row, true); + return; + } + rowVector->setNull(row, false); + sumVector->set(row, sum); + countVector->set(row, count); +} + +} // extern "C" From 061ba6d9c8b9f42491f02e8fef96178edc9590b9 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 10 Jun 2026 23:14:03 +0800 Subject: [PATCH 34/98] update doct --- doc/hash-aggr-jit-todolist.md | 41 +++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/doc/hash-aggr-jit-todolist.md b/doc/hash-aggr-jit-todolist.md index 63524e95e..c498aa0ca 100644 --- a/doc/hash-aggr-jit-todolist.md +++ b/doc/hash-aggr-jit-todolist.md @@ -2,6 +2,47 @@ ## Resolved(已处理,保留遗留风险备忘) +### HashAggrJitOps 散布在各 aggregate + Aggregate.h 硬依赖 LLVM 头 + +**问题** +- `Aggregate.h` `#include HashAggrJit.h`,后者 `#include ` 等重头,导致 + 所有 include `Aggregate.h` 的 TU(JIT 开启时)被拖进 LLVM IR 头,编译时间膨胀。 +- 每个 aggregate 子类内嵌一组 `compileHashAggrJit*` static codegen,依赖 IRBuilder; + `DecimalSumAggregate.h` 模板头塞了 ~120 行 codegen,每个实例化点重复展开。 +- runtime helper(`jit_HashAggrSetFlat*` 等)散落在 `RowContainer.cpp`,decimal extract + helper 散落在 `SumAggregate.cpp` / `AverageAggregate.cpp`。 + +**本次处理(四点 + 遗留点,已编译验证)** +1. 
剥离 LLVM 头出 `Aggregate.h`: + - 新建 `bolt/jit/aggregation/HashAggrJitTypes.h`(纯 metadata,无 LLVM): + state / decoded&output 描述符 / planContext / enum / `HashAggrJitDescriptor` + (`ops` 持有前向声明的 `HashAggrJitOps*`)/ `HashAggrJitSlot` / 三个自由函数声明 + / `getXxxOps()` 声明。 + - `HashAggrJit.h` 改为 `#include HashAggrJitTypes.h` + 仅保留 codegen-only + (`HashAggrJitOps` / `HashAggrJitExtractTarget` / `HashAggrJitCodegen` / `HashAggrJitChunk`)。 + - `Aggregate.h` 的 include 改为 `HashAggrJitTypes.h`,LLVM 头不再进公共头。 +2. 各 aggregate codegen 迁到 `bolt/jit/aggregation/ops/*Ops.cpp`: + `CountOps / MinMaxOps / SumOps / AvgOps / DecimalSumOps / DecimalAvgOps`,各 `getXxxOps()`; + 编入 `bolt_thrustjit`。aggregate 子类只留 `supportsHashAggrJit` + `createHashAggrJitDescriptor` + (`.ops = jit::getXxxOps()`)。须留类内:MinMax 的虚函数 `jitKind()`、Decimal 的 + `sumType_` / `resultType()` 依赖。 +3. runtime helper 迁到 `bolt/jit/aggregation/runtime/`: + - `HashAggrRuntime.cpp`:`jit_HashAggrResizeVector` / `SetFlat*` / `SetPartialAvgDouble` + (原在 `RowContainer.cpp`)。 + - `HashAggrDecimalRuntime.cpp`(遗留点):`jit_HashAggrExtract{Final,Partial}Decimal{Sum,Avg}` + + `jitDecimalSumComputeFinal`(原在 `SumAggregate.cpp` / `AverageAggregate.cpp`)。 + - 两文件只依赖 vector + `bolt/type/DecimalUtil.h`,编入 `bolt_exec`(同符号空间、 + `ENABLE_EXPORTS`,仍 extern "C" + visibility default,dlsym 可解析)。 +4. 
编译验证:`bolt_thrustjit` / `bolt_exec` / `bolt_aggregates` / + `bolt_functions_spark_aggregates` 均通过;nm 确认 `jit_HashAggr*` 12 个符号在新文件 + 以 `T` 导出,旧文件无残留定义。 + +**未做 / 遗留** +- 链接级端到端单测运行验证未做(当前为 release 纯库配置,无可执行 target)。 + 需要时配 `release_with_test` 跑 HashAggr JIT 单测。 +- 已知 `bolt/functions/sparksql/aggregates/CMakeLists.txt` 里 `SumAggregate.cpp` 被列两次 + (历史问题,非本次范围,未改)。 + ### Descriptor ↔ Slot 字段重复 / positional init 易错 **问题** From c15f2561bc70c85c5f28ee1de467eaac0b2111e3 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 00:33:15 +0800 Subject: [PATCH 35/98] remove more helper in ir --- bolt/jit/aggregation/HashAggrJit.cpp | 186 ++++++--------------- bolt/jit/aggregation/HashAggrJit.h | 11 ++ bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 105 +++++++----- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 98 ++++++----- doc/hash-aggr-jit-todolist.md | 44 +++++ doc/hashaggr-jit-benchmark.md | 44 +++++ 6 files changed, 267 insertions(+), 221 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 9fd17fd72..9db6015db 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -21,8 +21,6 @@ extern "C" { using bytedance::bolt::jit::HashAggrJitDecodedInput; using bytedance::bolt::jit::HashAggrJitOutput; -using bytedance::bolt::jit::JitDecimalAvgState; -using bytedance::bolt::jit::JitDecimalSumState; namespace { @@ -78,113 +76,30 @@ constexpr uint64_t kOutputRowFieldStride = offsetof(HashAggrJitOutput, rowField1Values) - offsetof(HashAggrJitOutput, rowField0Values); -int64_t jitHashAggrAddWithOverflow( - bytedance::bolt::int128_t left, - bytedance::bolt::int128_t right, - bytedance::bolt::int128_t& result) { - result = left + right; - if (left > 0 && right > 0 && result < 0) { - return 1; - } - if (left < 0 && right < 0 && result >= 0) { - return -1; - } - return 0; -} - } // namespace -__attribute__((__visibility__("default"))) void jit_HashAggrInitDecimalSum( - char* group, - 
int32_t offset) { - new (group + offset) JitDecimalSumState(); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrInitDecimalAvg( - char* group, - int32_t offset) { - new (group + offset) JitDecimalAvgState(); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalSumI64( - char* group, - int32_t offset, - int64_t value) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow( - state->sum, static_cast(value), state->sum); - state->isEmpty = false; -} - -__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalSumI128( +// Link anchor: the JIT extract/output runtime helpers live in separate +// translation units (HashAggrRuntime.cpp / HashAggrDecimalRuntime.cpp) and are +// only ever looked up by name through the ORC JIT global symbol table, never +// referenced at C++ link time. Without an explicit reference the linker would +// drop those objects from the final executable and the JIT would fail to +// resolve the symbols. Referencing one symbol per object forces the whole +// object (and thus every helper it defines) to be retained. This TU is always +// pulled in by any JIT user (HashAggrJitChunk), so the anchor propagates. 
+void jit_HashAggrResizeVector(char* vector, int32_t size); +void jit_HashAggrExtractFinalDecimalSum( + char* vector, + int32_t row, char* group, int32_t offset, - bytedance::bolt::int128_t value) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); - state->isEmpty = false; -} + int32_t precision, + int32_t scale, + int8_t longDecimal); -__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalAvgI64( - char* group, - int32_t offset, - int64_t value) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow( - state->sum, static_cast(value), state->sum); - ++state->count; -} - -__attribute__((__visibility__("default"))) void jit_HashAggrUpdateDecimalAvgI128( - char* group, - int32_t offset, - bytedance::bolt::int128_t value) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); - ++state->count; -} - -__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalSumI64( - char* group, - int32_t offset, - int64_t value, - int8_t isEmpty) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow( - state->sum, static_cast(value), state->sum); - state->isEmpty = state->isEmpty && static_cast(isEmpty); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalSumI128( - char* group, - int32_t offset, - bytedance::bolt::int128_t value, - int8_t isEmpty) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); - state->isEmpty = state->isEmpty && static_cast(isEmpty); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalAvgI64( - char* group, - int32_t offset, - int64_t value, - int64_t count) { - auto* state = reinterpret_cast(group + offset); - state->overflow += 
jitHashAggrAddWithOverflow( - state->sum, static_cast(value), state->sum); - state->count += count; -} - -__attribute__((__visibility__("default"))) void jit_HashAggrMergeDecimalAvgI128( - char* group, - int32_t offset, - bytedance::bolt::int128_t value, - int64_t count) { - auto* state = reinterpret_cast(group + offset); - state->overflow += jitHashAggrAddWithOverflow(state->sum, value, state->sum); - state->count += count; -} +[[maybe_unused]] __attribute__((used)) const void* const + kHashAggrRuntimeLinkAnchors[] = { + reinterpret_cast(&jit_HashAggrResizeVector), + reinterpret_cast(&jit_HashAggrExtractFinalDecimalSum)}; } // extern "C" @@ -233,36 +148,6 @@ void ensureBuiltinDeclarations(llvm::Module& module) { declareFunction( module, "jit_GetDecodedRowFieldIsNull", i8Ty, {i8PtrTy, i32Ty, i32Ty}); declareFunction(module, "jit_GetDecodedIsNull", i8Ty, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_HashAggrInitDecimalSum", voidTy, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_HashAggrInitDecimalAvg", voidTy, {i8PtrTy, i32Ty}); - declareFunction( - module, "jit_HashAggrUpdateDecimalSumI64", voidTy, {i8PtrTy, i32Ty, i64Ty}); - declareFunction( - module, "jit_HashAggrUpdateDecimalSumI128", voidTy, {i8PtrTy, i32Ty, i128Ty}); - declareFunction( - module, "jit_HashAggrUpdateDecimalAvgI64", voidTy, {i8PtrTy, i32Ty, i64Ty}); - declareFunction( - module, "jit_HashAggrUpdateDecimalAvgI128", voidTy, {i8PtrTy, i32Ty, i128Ty}); - declareFunction( - module, - "jit_HashAggrMergeDecimalSumI64", - voidTy, - {i8PtrTy, i32Ty, i64Ty, i8Ty}); - declareFunction( - module, - "jit_HashAggrMergeDecimalSumI128", - voidTy, - {i8PtrTy, i32Ty, i128Ty, i8Ty}); - declareFunction( - module, - "jit_HashAggrMergeDecimalAvgI64", - voidTy, - {i8PtrTy, i32Ty, i64Ty, i64Ty}); - declareFunction( - module, - "jit_HashAggrMergeDecimalAvgI128", - voidTy, - {i8PtrTy, i32Ty, i128Ty, i64Ty}); declareFunction(module, "jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); declareFunction(module, 
"jit_HashAggrSetFlatI8", voidTy, {i8PtrTy, i32Ty, i8Ty, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatI16", voidTy, {i8PtrTy, i32Ty, i16Ty, i8Ty}); @@ -913,6 +798,39 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( longDecimal}); } +void HashAggrJitCodegen::emitDecimalAddWithOverflow( + llvm::Value* group, + int32_t sumOffset, + int32_t overflowOffset, + llvm::Value* addend) const { + auto& b = builder(); + auto* i128Ty = b.getInt128Ty(); + auto* i64Ty = b.getInt64Ty(); + auto* zero128 = llvm::ConstantInt::get(i128Ty, 0); + + auto* oldSum = loadValue(group, i128Ty, sumOffset); + auto* newSum = b.CreateAdd(oldSum, addend); + storeValue(group, i128Ty, sumOffset, newSum); + + // Mirror jitHashAggrAddWithOverflow: + // +1 if a>0 && b>0 && result<0 (positive overflow) + // -1 if a<0 && b<0 && result>=0 (negative overflow) + auto* aPos = b.CreateICmpSGT(oldSum, zero128); + auto* bPos = b.CreateICmpSGT(addend, zero128); + auto* rNeg = b.CreateICmpSLT(newSum, zero128); + auto* posOverflow = b.CreateAnd(b.CreateAnd(aPos, bPos), rNeg); + + auto* aNeg = b.CreateICmpSLT(oldSum, zero128); + auto* bNeg = b.CreateICmpSLT(addend, zero128); + auto* rNonNeg = b.CreateICmpSGE(newSum, zero128); + auto* negOverflow = b.CreateAnd(b.CreateAnd(aNeg, bNeg), rNonNeg); + + auto* carry = b.CreateSub( + b.CreateZExt(posOverflow, i64Ty), b.CreateZExt(negOverflow, i64Ty)); + auto* oldOverflow = loadValue(group, i64Ty, overflowOffset); + storeValue(group, i64Ty, overflowOffset, b.CreateAdd(oldOverflow, carry)); +} + namespace { bool genAddDenseIR( diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 642e16983..7feb2d313 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -133,6 +133,17 @@ class HashAggrJitCodegen { const HashAggrJitSlot& slot, bool partialOutput) const; + // Inline i128 accumulate-with-overflow used by decimal sum/avg add+merge. 
+ // Loads the i128 sum at 'group + sumOffset' and the i64 overflow counter at + // 'group + overflowOffset', computes sum += addend, updates the overflow + // counter by the carry direction (mirrors jitHashAggrAddWithOverflow), and + // stores both back. Replaces the per-row runtime helper call with pure IR. + void emitDecimalAddWithOverflow( + llvm::Value* group, + int32_t sumOffset, + int32_t overflowOffset, + llvm::Value* addend) const; + private: llvm::Module& module_; llvm::IRBuilder<>* builder_{nullptr}; diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index 29c503f56..fbcec7bdf 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -5,20 +5,38 @@ #ifdef ENABLE_BOLT_JIT +#include + #include "bolt/jit/aggregation/HashAggrJit.h" namespace bytedance::bolt::jit { namespace { +// Field offsets within JitDecimalAvgState, relative to slot.offset. +constexpr int32_t kSumOffset = + static_cast(offsetof(JitDecimalAvgState, sum)); +constexpr int32_t kCountOffset = + static_cast(offsetof(JitDecimalAvgState, count)); +constexpr int32_t kOverflowOffset = + static_cast(offsetof(JitDecimalAvgState, overflow)); + void compileDecimalAvgInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot) { + auto& b = codegen.builder(); codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateCall( - codegen.module().getFunction("jit_HashAggrInitDecimalAvg"), - {group, codegen.builder().getInt32(slot.offset)}); + // sum = 0 (i128), count = 0 (i64), overflow = 0 (i64). 
+ codegen.storeValue( + group, + b.getInt128Ty(), + slot.offset + kSumOffset, + llvm::ConstantInt::get(b.getInt128Ty(), 0)); + codegen.storeValue( + group, b.getInt64Ty(), slot.offset + kCountOffset, b.getInt64(0)); + codegen.storeValue( + group, b.getInt64Ty(), slot.offset + kOverflowOffset, b.getInt64(0)); } void compileDecimalAvgAddRawInput( @@ -29,19 +47,21 @@ void compileDecimalAvgAddRawInput( const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { + auto& b = codegen.builder(); auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* value = + codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? "jit_HashAggrUpdateDecimalAvgI128" - : "jit_HashAggrUpdateDecimalAvgI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? codegen.castValue( - rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128) - : rawValue}); + codegen.emitDecimalAddWithOverflow( + group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + // ++count. 
+ auto* oldCount = + codegen.loadValue(group, b.getInt64Ty(), slot.offset + kCountOffset); + codegen.storeValue( + group, + b.getInt64Ty(), + slot.offset + kCountOffset, + b.CreateAdd(oldCount, b.getInt64(1))); } void compileDecimalAvgAddIntermediateResults( @@ -52,7 +72,8 @@ void compileDecimalAvgAddIntermediateResults( const HashAggrJitSlot& slot, bool, llvm::BasicBlock* nextBlock) { - auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto& b = codegen.builder(); + auto* function = b.GetInsertBlock()->getParent(); auto* continueBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "avg_decimal_merge_cont", @@ -64,44 +85,38 @@ void compileDecimalAvgAddIntermediateResults( function, continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "avg_decimal_merge", - function, - continueBlock); + codegen.module().getContext(), "avg_decimal_merge", function, continueBlock); auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); auto* count = codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int64); - auto* countPositive = - codegen.builder().CreateICmpSGT(count, codegen.builder().getInt64(0)); - auto* isOverflow = codegen.builder().CreateAnd( - sumIsNull, - codegen.builder().CreateAnd( - codegen.builder().CreateNot(countIsNull), countPositive)); - codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); - - codegen.builder().SetInsertPoint(overflowBlock); + auto* countPositive = b.CreateICmpSGT(count, b.getInt64(0)); + auto* isOverflow = b.CreateAnd( + sumIsNull, b.CreateAnd(b.CreateNot(countIsNull), countPositive)); + b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); + + b.SetInsertPoint(overflowBlock); codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateBr(continueBlock); + b.CreateBr(continueBlock); - codegen.builder().SetInsertPoint(mergeBlock); + 
b.SetInsertPoint(mergeBlock); auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); + auto* value = + codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? "jit_HashAggrMergeDecimalAvgI128" - : "jit_HashAggrMergeDecimalAvgI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? codegen.castValue( - sum, slot.desc.inputKind, HashAggrJitValueKind::Int128) - : sum, - count}); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(continueBlock); + codegen.emitDecimalAddWithOverflow( + group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + // count += incoming count. + auto* oldCount = + codegen.loadValue(group, b.getInt64Ty(), slot.offset + kCountOffset); + codegen.storeValue( + group, + b.getInt64Ty(), + slot.offset + kCountOffset, + b.CreateAdd(oldCount, count)); + b.CreateBr(continueBlock); + + b.SetInsertPoint(continueBlock); } bool canCompileDecimalAvgExtract(const HashAggrJitSlot&, bool partialOutput) { diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 38c226589..f829c9ad2 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -5,20 +5,38 @@ #ifdef ENABLE_BOLT_JIT +#include + #include "bolt/jit/aggregation/HashAggrJit.h" namespace bytedance::bolt::jit { namespace { +// Field offsets within JitDecimalSumState, relative to slot.offset. 
+constexpr int32_t kSumOffset = + static_cast(offsetof(JitDecimalSumState, sum)); +constexpr int32_t kOverflowOffset = + static_cast(offsetof(JitDecimalSumState, overflow)); +constexpr int32_t kIsEmptyOffset = + static_cast(offsetof(JitDecimalSumState, isEmpty)); + void compileDecimalSumInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot) { + auto& b = codegen.builder(); codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateCall( - codegen.module().getFunction("jit_HashAggrInitDecimalSum"), - {group, codegen.builder().getInt32(slot.offset)}); + // sum = 0 (i128), overflow = 0 (i64), isEmpty = true (i8). + codegen.storeValue( + group, + b.getInt128Ty(), + slot.offset + kSumOffset, + llvm::ConstantInt::get(b.getInt128Ty(), 0)); + codegen.storeValue( + group, b.getInt64Ty(), slot.offset + kOverflowOffset, b.getInt64(0)); + codegen.storeValue( + group, b.getInt8Ty(), slot.offset + kIsEmptyOffset, b.getInt8(1)); } void compileDecimalSumAddRawInput( @@ -29,19 +47,15 @@ void compileDecimalSumAddRawInput( const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { + auto& b = codegen.builder(); auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* value = + codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? "jit_HashAggrUpdateDecimalSumI128" - : "jit_HashAggrUpdateDecimalSumI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? 
codegen.castValue( - rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128) - : rawValue}); + codegen.emitDecimalAddWithOverflow( + group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + codegen.storeValue( + group, b.getInt8Ty(), slot.offset + kIsEmptyOffset, b.getInt8(0)); } void compileDecimalSumAddIntermediateResults( @@ -52,7 +66,8 @@ void compileDecimalSumAddIntermediateResults( const HashAggrJitSlot& slot, bool, llvm::BasicBlock* nextBlock) { - auto* function = codegen.builder().GetInsertBlock()->getParent(); + auto& b = codegen.builder(); + auto* function = b.GetInsertBlock()->getParent(); auto* continueBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "sum_decimal_merge_cont", @@ -64,40 +79,39 @@ void compileDecimalSumAddIntermediateResults( function, continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), - "sum_decimal_merge", - function, - continueBlock); + codegen.module().getContext(), "sum_decimal_merge", function, continueBlock); auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* isEmpty = + auto* incomingIsEmpty = codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int8); - auto* isNotEmpty = - codegen.builder().CreateICmpEQ(isEmpty, codegen.builder().getInt8(0)); - auto* isOverflow = codegen.builder().CreateAnd(sumIsNull, isNotEmpty); - codegen.builder().CreateCondBr(isOverflow, overflowBlock, mergeBlock); + auto* isNotEmpty = b.CreateICmpEQ(incomingIsEmpty, b.getInt8(0)); + auto* isOverflow = b.CreateAnd(sumIsNull, isNotEmpty); + b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); - codegen.builder().SetInsertPoint(overflowBlock); + b.SetInsertPoint(overflowBlock); codegen.setAccumulatorNull(group, slot); - codegen.builder().CreateBr(continueBlock); + b.CreateBr(continueBlock); - codegen.builder().SetInsertPoint(mergeBlock); + b.SetInsertPoint(mergeBlock); auto* sum = codegen.loadDecodedRowField(decoded, row, 0, 
slot.desc.inputKind); + auto* value = + codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - const auto helper = slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? "jit_HashAggrMergeDecimalSumI128" - : "jit_HashAggrMergeDecimalSumI64"; - codegen.builder().CreateCall( - codegen.module().getFunction(helper), - {group, - codegen.builder().getInt32(slot.offset), - slot.desc.inputKind == HashAggrJitValueKind::Int128 - ? codegen.castValue( - sum, slot.desc.inputKind, HashAggrJitValueKind::Int128) - : sum, - isEmpty}); - codegen.builder().CreateBr(continueBlock); - - codegen.builder().SetInsertPoint(continueBlock); + codegen.emitDecimalAddWithOverflow( + group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + // isEmpty = isEmpty && incomingIsEmpty. + auto* oldIsEmpty = + codegen.loadValue(group, b.getInt8Ty(), slot.offset + kIsEmptyOffset); + auto* bothEmpty = b.CreateAnd( + b.CreateICmpNE(oldIsEmpty, b.getInt8(0)), + b.CreateICmpNE(incomingIsEmpty, b.getInt8(0))); + codegen.storeValue( + group, + b.getInt8Ty(), + slot.offset + kIsEmptyOffset, + b.CreateZExt(bothEmpty, b.getInt8Ty())); + b.CreateBr(continueBlock); + + b.SetInsertPoint(continueBlock); } bool canCompileDecimalSumExtract(const HashAggrJitSlot&, bool) { diff --git a/doc/hash-aggr-jit-todolist.md b/doc/hash-aggr-jit-todolist.md index c498aa0ca..027f9bf48 100644 --- a/doc/hash-aggr-jit-todolist.md +++ b/doc/hash-aggr-jit-todolist.md @@ -93,6 +93,50 @@ ## Pending +### [P0] JIT add/merge+extract 路径正确性 bug,被 test 链接丢符号长期掩盖 + +**现象** +- 单测 `SumAggregationTest.hashAggrJitMergeAndExtract` 与 + `SumAggregationTest.hashAggrJitAllNullGroup`(均为 partial+final 两阶段、非 decimal) + 在 JIT 路径**真正执行**时结果错误: + - `hashAggrJitAllNullGroup`:group sum 期望 12,得 0。 + - `hashAggrJitMergeAndExtract`:sum/avg/min 全 null、count 全 0,相当于 add 完全没生效。 +- JIT 模块成功编译执行(无 "Symbols not found" / 无 fallback 日志),是**执行结果错**, + 不是回退。 + +**根因定位(已用 git stash 二分确认)** +- 
与本轮 decimal IR 化改动**无关**:在干净 HEAD 上、仅加一个把 runtime 符号 + (如 `jit_HashAggrResizeVector`)拉进 test 可执行的 link anchor,这两个用例即 FAIL。 +- 真正背景:commit `4cbfc5e590`(runtime helper 迁出 `RowContainer.cpp` 到独立 .o)后, + 这些 `jit_HashAggr*` 符号**未被 test/可执行链接**(无 C++ 引用,.o 被链接器丢弃)。 + 于是 JIT 在 test 二进制里 materialize 失败 → **静默回退非 JIT** → 结果恰好正确 → + **掩盖了 JIT 路径本身的既有正确性 bug**。 +- 本轮 decimal 改动新增的 link anchor 把这些符号拉回可执行,JIT 路径终于被真正执行, + 从而**暴露**(非引入)该 bug。 + +**潜在影响(需进一步确认)** +- 若生产可执行同样没有引用这些 runtime .o,则 HashAggr JIT 在生产里可能**根本没在跑** + (一直静默回退非 JIT)。需要核实生产链接是否包含这些符号。 +- 一旦修复链接(让 JIT 真正执行),这个 add/merge+extract 正确性 bug 会立刻显现, + 必须在「启用 JIT 执行」之前先修。 + +**后续待办** +- 定位 add/merge+extract 在两阶段非 decimal 场景下结果归零/全 null 的根因 + (疑点:partial extract 与 final merge 的累加器布局 / null 语义,可能与 + commit `f74cc21160` 删除 `numNulls_` 同步相关——`allNullGroup` 正是该语义守护用例)。 +- 当前 decimal 改动保留了 link anchor(benchmark 需要它,否则 JIT 符号缺失); + 注意 anchor 会让上述 bug 在跑相关单测时显现为 FAILED。 + +**⚠️ 合入注意** +- 本轮 decimal IR 改动保留了 link anchor,启用后 JIT 路径会真正执行,导致 + `hashAggrJitMergeAndExtract` / `hashAggrJitAllNullGroup` 两个单测**变红(FAILED)**。 +- 这不是 decimal 改动引入的回归,而是上述既有 bug 被暴露;但**合入前必须先修该 P0 bug, + 否则 CI 会红**。两个选项: + 1. 先修 add/merge+extract 正确性 bug,再合入(推荐)。 + 2. 临时移除 link anchor —— 但那样 benchmark 里 JIT 符号又会解析失败、JIT 回退, + decimal 性能改善无法体现。 +- 简言之:**link anchor + 既有 bug 是绑定的**,要么一起修好,要么都先不动。 + ### [P2] chunk 同时 codegen `add_dense` 和 `add_dense_no_null`,编译时间与产物 ×2 **现状** diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 72a7bfaff..30e0f19ac 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -1095,3 +1095,47 @@ P2 6.20 / 6.31 / 6.25ms——互有高低,落在噪声范围内。 RowContainer 中的非连续布局),而非已被预取覆盖的 `indices[row]` 间接寻址。 - P3(下沉 per-row accumulator null clear)的待确认正确性约束(新组创建与首次更新 是否同 batch)经评估不成立、争议较大,暂缓,不在本轮实施。 + +--- + +## 13. 
Decimal sum/avg add/merge 纯 IR 化 + +### 13.1 背景 + +decimal sum/avg 的 add/merge 主路径此前不是真正的 inline IR:每行通过 +`CreateCall(jit_HashAggrUpdate/MergeDecimal*)` 把 i128 加法 + 溢出检测转交 C++ +runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路由,真正 +算子在跨函数调用里执行——付出了 LLVM 的代价却没拿到 inline 红利。 + +### 13.2 改动 + +- 新增 `HashAggrJitCodegen::emitDecimalAddWithOverflow`:纯 IR 实现 i128 + `CreateAdd` + 溢出检测(`(a>0&&b>0&&r<0)||(a<0&&b<0&&r>=0)`,≤8 条 IR), + 溢出计数用 `posOverflow - negOverflow` 累加。 +- `DecimalSumOps` / `DecimalAvgOps` 的 init/add-raw/add-merge 全部改为纯 IR: + - init:直接 store sum/overflow/(count|isEmpty),替代 `jit_HashAggrInitDecimal*`。 + - add/merge:`emitDecimalAddWithOverflow` + IR 内 `++count` / `isEmpty &&=`。 + - state 字段访问用 `offsetof(JitDecimal*State, field)` 派生 offset,避免硬编码。 +- 删除不再被调用的 `jit_HashAggrInit/Update/MergeDecimal*` runtime helper 及其 + builtin 声明、`jitHashAggrAddWithOverflow`。 +- per-row 的跨函数调用从 N 次降为 0(add 主路径全部内联到循环体)。 + +### 13.3 性能(width8,bm_min_iters=50) + +| case | 改前 jit | 改后 jit | nojit(参考) | 改善 | +|------|----------|----------|----------------|------| +| width8_decimal_sum | 9.86ms | **9.01ms** | 11.79ms | ~9% | +| width8_decimal_avg | 14.75ms | **13.88ms** | 16.72ms | ~6% | + +- nojit 基线基本不变,说明提升来自 JIT 侧 add/merge 内联,而非环境波动。 +- 多轮测量 decimal_sum_jit 稳定在 8.1–9.0ms 区间(取决于机器负载),均优于改前。 +- 收益幅度小于「翻倍」的乐观预期:i128 算术本身有成本,且热循环还有 group + 寻址 / null 处理开销,per-row call 的消除只压缩了其中一部分。 + +### 13.4 正确性 + +- decimal 专项单测全部通过:`decimalSum` / `decimalGlobalSumOverflow` / + `decimalGroupBySumOverflow` / `decimalLargeCountRowsOverflow` / + `decimalSomeGroupsAllnullValues`(覆盖溢出、全 null 组等关键路径)。 +- extract 的 decimal 计算(依赖 `DecimalUtil` 精度判定、每组一次、非热路径) + 保留 runtime helper,不在本次范围。 From cb8e15a230dd0055a0558804d50875f1c048858b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 12:45:58 +0800 Subject: [PATCH 36/98] remove helper function jit_HashAggrSetPartialAvgDouble --- bolt/exec/GroupingSet.cpp | 19 ++++++-- bolt/jit/aggregation/HashAggrJit.cpp | 41 ++++------------ 
.../aggregation/runtime/HashAggrRuntime.cpp | 20 -------- doc/hashaggr-jit-benchmark.md | 48 +++++++++++++++++++ 4 files changed, 72 insertions(+), 56 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 557abb906..d5ea6f620 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -145,18 +145,22 @@ void fillHashAggrJitRowFieldInputs( input.rowField1Nulls = countVector->rawNulls(); } -void fillHashAggrJitPartialAvgOutput( +// Fills the raw flat sum/count field pointers for a partial avg ROW output. +// Returns false when the ROW children are not both FLAT (e.g. dictionary/ +// constant wrapped), in which case the JIT fast path is not applicable and the +// caller must fall back to the non-JIT extract path. +bool fillHashAggrJitPartialAvgOutput( jit::HashAggrJitOutput& output, BaseVector* vector) { auto* rowVector = vector->asUnchecked(); if (rowVector->childrenSize() < 2) { - return; + return false; } auto& sumVector = rowVector->childAt(0); auto& countVector = rowVector->childAt(1); if (sumVector->encoding() != VectorEncoding::Simple::FLAT || countVector->encoding() != VectorEncoding::Simple::FLAT) { - return; + return false; } output.rowField0Values = sumVector->asUnchecked>()->mutableRawValues(); @@ -164,6 +168,7 @@ void fillHashAggrJitPartialAvgOutput( output.rowField1Values = countVector->asUnchecked>()->mutableRawValues(); output.rowField1Nulls = countVector->mutableRawNulls(); + return true; } std::string hashAggrJitSlotDebugString( @@ -1237,8 +1242,12 @@ void GroupingSet::runHashAggrJitExtractChunks( } else if (aggregateVector->encoding() == VectorEncoding::Simple::ROW && slot.desc.kind == jit::HashAggrJitKind::Avg) { hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); - fillHashAggrJitPartialAvgOutput( - hashAggrJitOutputs_[slotIndex], aggregateVector.get()); + if (!fillHashAggrJitPartialAvgOutput( + hashAggrJitOutputs_[slotIndex], aggregateVector.get())) { + canRunChunk = false; + 
skipReason = "partial avg row fields are not flat"; + break; + } } hashAggrJitResultPtrs_[slotIndex] = reinterpret_cast(&hashAggrJitOutputs_[slotIndex]); diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 9db6015db..8d2c44c84 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -155,11 +155,6 @@ void ensureBuiltinDeclarations(llvm::Module& module) { declareFunction(module, "jit_HashAggrSetFlatI64", voidTy, {i8PtrTy, i32Ty, i64Ty, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatFloat", voidTy, {i8PtrTy, i32Ty, floatTy, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatDouble", voidTy, {i8PtrTy, i32Ty, doubleTy, i8Ty}); - declareFunction( - module, - "jit_HashAggrSetPartialAvgDouble", - voidTy, - {i8PtrTy, i32Ty, doubleTy, i64Ty, i8Ty}); // Decimal extract helpers: (vector, row, group, offset, precision, scale, // longDecimal). declareFunction( @@ -631,7 +626,8 @@ llvm::Value* HashAggrJitCodegen::loadDecodedRowField( !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); return builder().CreateCall( - module_.getFunction(name), {decodedVector, row, builder().getInt32(field)}); + module_.getFunction(name), + {decodedVector, row, builder().getInt32(field)}); } llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( @@ -695,7 +691,8 @@ void HashAggrJitCodegen::emitFlatValue( } auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( - module_.getFunction(setter), {vector, row, value, isNull}); + module_.getFunction(setter), + {vector, row, value, isNull}); } void HashAggrJitCodegen::resizeResultVector( @@ -703,7 +700,8 @@ void HashAggrJitCodegen::resizeResultVector( llvm::Value* size) const { auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( - module_.getFunction("jit_HashAggrResizeVector"), 
{vector, size}); + module_.getFunction("jit_HashAggrResizeVector"), + {vector, size}); } void HashAggrJitCodegen::emitPartialAvgResult( @@ -712,21 +710,12 @@ void HashAggrJitCodegen::emitPartialAvgResult( llvm::Value* sum, llvm::Value* count, llvm::Value* isNull) const { + // The extract admission path (runHashAggrJitExtractChunks) guarantees the + // partial avg ROW output has flat sum/count children before the chunk runs, + // so rowField0/1 values are always populated and we can write them directly + // without a runtime fast/helper branch. auto* sumValues = ::bytedance::bolt::jit::loadOutputRowFieldPointer( builder(), output, 0, false); - auto* hasRawRowOutput = builder().CreateICmpNE( - sumValues, - llvm::ConstantPointerNull::get( - llvm::PointerType::get(builder().getContext(), 0))); - auto* fastBlock = llvm::BasicBlock::Create( - module_.getContext(), "partial_avg_raw", builder().GetInsertBlock()->getParent()); - auto* helperBlock = llvm::BasicBlock::Create( - module_.getContext(), "partial_avg_helper", builder().GetInsertBlock()->getParent()); - auto* doneBlock = llvm::BasicBlock::Create( - module_.getContext(), "partial_avg_done", builder().GetInsertBlock()->getParent()); - builder().CreateCondBr(hasRawRowOutput, fastBlock, helperBlock); - - builder().SetInsertPoint(fastBlock); auto* sumTypedValues = builder().CreatePointerCast(sumValues, builder().getDoubleTy()->getPointerTo()); auto* countValues = ::bytedance::bolt::jit::loadOutputRowFieldPointer( @@ -742,16 +731,6 @@ void HashAggrJitCodegen::emitPartialAvgResult( countStore->setAlignment(llvm::Align(1)); auto* nulls = ::bytedance::bolt::jit::loadOutputNulls(builder(), output); ::bytedance::bolt::jit::emitOutputNullBit(builder(), nulls, row, isNull); - builder().CreateBr(doneBlock); - - builder().SetInsertPoint(helperBlock); - auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); - builder().CreateCall( - module_.getFunction("jit_HashAggrSetPartialAvgDouble"), - {vector, row, sum, 
count, isNull}); - builder().CreateBr(doneBlock); - - builder().SetInsertPoint(doneBlock); } void HashAggrJitCodegen::emitDecimalSumExtract( diff --git a/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp index 457a9f1ff..c1b5d7e3f 100644 --- a/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp +++ b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp @@ -10,7 +10,6 @@ // They were previously colocated in RowContainer.cpp purely because the // jit_GetDecodedValue* helpers already lived there. -#include "bolt/vector/ComplexVector.h" #include "bolt/vector/FlatVector.h" extern "C" { @@ -81,23 +80,4 @@ __attribute__((__visibility__("default"))) void jit_HashAggrSetFlatDouble( isNull ? flat->setNull(row, true) : flat->set(row, value); } -__attribute__((__visibility__("default"))) void jit_HashAggrSetPartialAvgDouble( - char* vector, - int32_t row, - double sum, - int64_t count, - int8_t isNull) { - auto* rowVector = reinterpret_cast(vector) - ->as(); - auto* sumVector = rowVector->childAt(0)->asFlatVector(); - auto* countVector = rowVector->childAt(1)->asFlatVector(); - if (isNull) { - rowVector->setNull(row, true); - return; - } - rowVector->setNull(row, false); - sumVector->set(row, sum); - countVector->set(row, count); -} - } // extern "C" diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 30e0f19ac..57cdb6762 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -1139,3 +1139,51 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 `decimalSomeGroupsAllnullValues`(覆盖溢出、全 null 组等关键路径)。 - extract 的 decimal 计算(依赖 `DecimalUtil` 精度判定、每组一次、非热路径) 保留 runtime helper,不在本次范围。 + +--- + +## 14. partial avg extract 去掉运行时 fast/helper 分支 + +### 14.1 背景 + +`emitPartialAvgResult` 此前在 IR 里有 `hasRawRowOutput ? 
fast : helper` 的运行时 +分支(3 个 BasicBlock + 1 条件跳转):当 partial avg 输出 ROW 的 sum/count 子字段 +为 FLAT 时走直写 fast 路径,否则回退 `jit_HashAggrSetPartialAvgDouble` helper。 +但该分支判定的是**循环不变量**(`rowField0Values` 在整个 extract 调用内不变)。 + +### 14.2 改动 + +- 把 fast/helper 的选择从「运行时」前移到「extract 准入」: + `fillHashAggrJitPartialAvgOutput` 改为返回 bool,当 ROW 子字段非 FLAT + (dictionary/constant 包装)时返回 false;`runHashAggrJitExtractChunks` 据此 + 令 `canRunChunk=false`、回退非 JIT 并打 VLOG(`skipReason="partial avg row + fields are not flat"`)。 +- 这样保证进入 JIT 的 chunk 其 rowField0/1 必被填充,IR 里直接走纯 fast 路径。 +- `emitPartialAvgResult` 删除运行时分支与 3 个 BasicBlock;删除不再被调用的 + `jit_HashAggrSetPartialAvgDouble` runtime helper、builtin 声明及其 + `ComplexVector.h` include。 + +### 14.3 性能(bm_min_iters=50,基线=分支版,优化=纯 fast) + +| case | 基线 | 优化后(2 轮) | 变化 | +|------|------|--------------|------| +| width8_avg_jit | 4.26ms | 4.15 / 4.21ms | ~持平–3% | +| width16_avg_jit | 8.75ms | 7.55 / 7.58ms | ~14% | +| width8_merge_avg_jit | 6.22ms | 5.75 / 6.17ms | 波动,约 0–8% | +| width16_merge_avg_jit | 11.44ms | 10.85 / 10.70ms | ~5–6% | +| width8_high_card_partial_avg_extract_jit | 74.13ms | 70.00 / 68.84ms | ~6–7% | + +- 整体小幅改善或持平,无回归。改善幅度有限且部分用例有运行间波动——符合预期: + 被删的分支是循环不变量,LLVM LICM + 分支预测本就覆盖了大部分开销,去掉它主要 + 减少了 codegen 出的 BasicBlock 数与少量恒命中的比较/跳转。 +- 价值更多在**正确性与可维护性**:把「子字段非 FLAT」从 IR 兜底分支收敛为 plan + 阶段的显式准入回退,IR 不再生成永远走同一侧的运行时分叉。 + +### 14.4 正确性 + +- partial avg / average 相关单测通过:`hashAggrJitPartialAvgExtractAccumulators` + (直接覆盖本次 fast 路径)、`avgDecimal` / `avgAllNulls` / + `rowBasedSpillDecimalAvg` / `hashAggrJitDecimalSumAndFloatingMinMax` / + `hashAggrJitSplitsContiguousSegments`。 +- 无新增回归(`hashAggrJitMergeAndExtract` / `hashAggrJitAllNullGroup` 仍 FAIL, + 系既有 P0 bug,见 todolist,与本次无关)。 From 804fac91457bb789d537ff9cb549a43f83c02c69 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 13:05:32 +0800 Subject: [PATCH 37/98] minor refactor to improve HashAggrJitChunk --- bolt/exec/GroupingSet.cpp | 6 +++--- bolt/exec/GroupingSet.h | 2 ++ 
bolt/jit/aggregation/HashAggrJit.cpp | 25 ++++++++++++------------- bolt/jit/aggregation/HashAggrJit.h | 9 ++------- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index d5ea6f620..a5a28364d 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -211,7 +211,7 @@ std::string hashAggrJitChunkDebugString( out << hashAggrJitSlotDebugString(slot, &aggregates[slot.aggregateIndex]); } out << "] canExtract=" << chunk.canExtract() - << " enabled=" << chunk.enabled(); + << " codegenReady=" << chunk.isCodegenReady(); return out.str(); } #endif @@ -1084,8 +1084,8 @@ void GroupingSet::runHashAggrJitChunks( jitExecuted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { - if (!chunk.enabled()) { - VLOG(1) << "HashAggrJit chunk disabled, skip add: " + if (!chunk.isCodegenReady()) { + VLOG(1) << "HashAggrJit chunk is not codegen-ready, skip add: " << hashAggrJitChunkDebugString(chunk, aggregates_); continue; } diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index c38b143d7..7c47782e7 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -465,6 +465,8 @@ class GroupingSet { std::vector hashAggrJitChunks_; std::vector hashAggrJitDecoded_; std::vector hashAggrJitDecodedInputs_; + // Keeps input vectors alive for the DecodedVector buffers referenced by + // JIT during addDense. 
std::vector hashAggrJitInputVectors_; std::vector hashAggrJitDecodedPtrs_; std::vector hashAggrJitNewGroups_; diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 8d2c44c84..27f74bf6e 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -854,7 +854,7 @@ bool genInitIR( for (const auto& slot : slots) { if (slot.desc.ops == nullptr || slot.desc.ops->initGroup == nullptr) { - return true; + return false; } slot.desc.ops->initGroup(codegen, group, slot); } @@ -866,7 +866,7 @@ bool genInitIR( builder.SetInsertPoint(end); builder.CreateRetVoid(); - return llvm::verifyFunction(*func, &llvm::errs()); + return !llvm::verifyFunction(*func, &llvm::errs()); } bool genAddDenseIR( @@ -927,12 +927,12 @@ bool genAddDenseIR( builder.SetInsertPoint(updateBlock); if (slot.desc.ops == nullptr) { - return true; + return false; } auto* addFn = slot.desc.mergeInput ? slot.desc.ops->addIntermediateResults : slot.desc.ops->addRawInput; if (addFn == nullptr) { - return true; + return false; } addFn(codegen, group, decoded, row, slot, checkInputNulls, nextBlock); builder.CreateBr(nextBlock); @@ -946,7 +946,7 @@ bool genAddDenseIR( builder.SetInsertPoint(end); builder.CreateRetVoid(); - return llvm::verifyFunction(*func, &llvm::errs()); + return !llvm::verifyFunction(*func, &llvm::errs()); } std::string setFlatValueFunction(HashAggrJitValueKind kind) { @@ -1025,7 +1025,7 @@ bool genExtractIR( auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); if (slot.desc.ops->extract == nullptr) { - return true; + return false; } slot.desc.ops->extract( codegen, group, slot, HashAggrJitExtractTarget{vector, row, partialOutput}); @@ -1038,7 +1038,7 @@ bool genExtractIR( builder.SetInsertPoint(end); builder.CreateRetVoid(); - return llvm::verifyFunction(*func, &llvm::errs()); + return !llvm::verifyFunction(*func, &llvm::errs()); } } // 
namespace @@ -1135,7 +1135,7 @@ std::string HashAggrJitChunk::functionName() const { } bool HashAggrJitChunk::canExtract() const { - if (extract_ == nullptr || disabled_) { + if (extract_ == nullptr) { return false; } for (const auto& slot : slots_) { @@ -1174,10 +1174,11 @@ bool HashAggrJitChunk::codegen() { const auto extractFn = extractFunctionName(); module_ = jit->CompileModule( [&](llvm::Module& module) { - const bool hasError = genInitIR(module, initFn, slots_) || - genAddDenseIR(module, addFn, slots_, true) || - genAddDenseIR(module, addNoNullFn, slots_, false) || + const bool ok = genInitIR(module, initFn, slots_) && + genAddDenseIR(module, addFn, slots_, true) && + genAddDenseIR(module, addNoNullFn, slots_, false) && genExtractIR(module, extractFn, slots_, partialOutput_); + const bool hasError = !ok; logHashAggrJitFunctionIR(module, moduleKey, initFn, "init", hasError); logHashAggrJitFunctionIR(module, moduleKey, addFn, "add_dense", hasError); logHashAggrJitFunctionIR( @@ -1192,7 +1193,6 @@ bool HashAggrJitChunk::codegen() { }, moduleKey); if (!module_) { - disabled_ = true; return false; } init_ = reinterpret_cast(module_->getFuncPtr(initFn)); @@ -1202,7 +1202,6 @@ bool HashAggrJitChunk::codegen() { extract_ = reinterpret_cast(module_->getFuncPtr(extractFn)); if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr || extract_ == nullptr) { - disabled_ = true; return false; } return true; diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 7feb2d313..c3d2d750b 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -161,16 +161,12 @@ class HashAggrJitChunk { bool codegen(); - bool enabled() const { - return addDense_ != nullptr && !disabled_; + bool isCodegenReady() const { + return addDense_ != nullptr; } bool canExtract() const; - void disable() { - disabled_ = true; - } - void init(char** newGroups, int32_t numNewGroups) const { init_(newGroups, 
numNewGroups); } @@ -208,7 +204,6 @@ class HashAggrJitChunk { HashAggrJitAddDenseFunc addDense_{nullptr}; HashAggrJitAddDenseFunc addDenseNoNull_{nullptr}; HashAggrJitExtractFunc extract_{nullptr}; - bool disabled_{false}; }; } // namespace bytedance::bolt::jit From 7ef69cd175afd21694a0e826a97daba6dbfb95a7 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 14:20:50 +0800 Subject: [PATCH 38/98] cache function name in HashAggrJitChunk --- bolt/jit/aggregation/HashAggrJit.cpp | 56 +++++------- bolt/jit/aggregation/HashAggrJit.h | 24 +++++- doc/hashaggr-jit-benchmark.md | 124 +++++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 37 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 27f74bf6e..0c93f9a9a 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1046,7 +1046,24 @@ bool genExtractIR( HashAggrJitChunk::HashAggrJitChunk( std::vector slots, bool partialOutput) - : slots_(std::move(slots)), partialOutput_(partialOutput) {} + : slots_(std::move(slots)), partialOutput_(partialOutput) { + std::ostringstream out; + out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" + << slots_.size(); + for (const auto& slot : slots_) { + out << "_" << (slot.desc.ops != nullptr ? slot.desc.ops->id : "unknown") << "_" + << static_cast(slot.desc.kind) << hashAggrJitValueKindName(slot.desc.inputKind) + << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" << slot.offset + << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) + << (slot.desc.countStar ? "s" : "x") << (slot.desc.mergeInput ? "g" : "r") + << (slot.desc.decimal ? 
"d" : "n"); + } + functionName_ = out.str(); + initFunctionName_ = functionName_ + "_init"; + addDenseFunctionName_ = functionName_ + "_add_dense"; + addDenseNoNullFunctionName_ = functionName_ + "_add_dense_no_null"; + extractFunctionName_ = functionName_ + "_extract"; +} std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { switch (kind) { @@ -1119,21 +1136,6 @@ std::string HashAggrJitDescriptor::signature() const { decimal); } -std::string HashAggrJitChunk::functionName() const { - std::ostringstream out; - out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" - << slots_.size(); - for (const auto& slot : slots_) { - out << "_" << (slot.desc.ops != nullptr ? slot.desc.ops->id : "unknown") << "_" - << static_cast(slot.desc.kind) << hashAggrJitValueKindName(slot.desc.inputKind) - << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" << slot.offset - << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) - << (slot.desc.countStar ? "s" : "x") << (slot.desc.mergeInput ? "g" : "r") - << (slot.desc.decimal ? 
"d" : "n"); - } - return out.str(); -} - bool HashAggrJitChunk::canExtract() const { if (extract_ == nullptr) { return false; @@ -1147,18 +1149,6 @@ bool HashAggrJitChunk::canExtract() const { return true; } -std::string HashAggrJitChunk::initFunctionName() const { - return functionName() + "_init"; -} - -std::string HashAggrJitChunk::addDenseNoNullFunctionName() const { - return functionName() + "_add_dense_no_null"; -} - -std::string HashAggrJitChunk::extractFunctionName() const { - return functionName() + "_extract"; -} - bool HashAggrJitChunk::codegen() { if (addDense_) { return true; @@ -1167,11 +1157,11 @@ bool HashAggrJitChunk::codegen() { if (jit == nullptr) { return false; } - const auto moduleKey = functionName(); - const auto initFn = initFunctionName(); - const auto addFn = moduleKey + "_add_dense"; - const auto addNoNullFn = addDenseNoNullFunctionName(); - const auto extractFn = extractFunctionName(); + const auto& moduleKey = functionName_; + const auto& initFn = initFunctionName_; + const auto& addFn = addDenseFunctionName_; + const auto& addNoNullFn = addDenseNoNullFunctionName_; + const auto& extractFn = extractFunctionName_; module_ = jit->CompileModule( [&](llvm::Module& module) { const bool ok = genInitIR(module, initFn, slots_) && diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index c3d2d750b..a6316c247 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -191,14 +191,30 @@ class HashAggrJitChunk { return slots_; } - std::string functionName() const; - std::string initFunctionName() const; - std::string addDenseNoNullFunctionName() const; - std::string extractFunctionName() const; + const std::string& functionName() const { + return functionName_; + } + const std::string& initFunctionName() const { + return initFunctionName_; + } + const std::string& addDenseFunctionName() const { + return addDenseFunctionName_; + } + const std::string& addDenseNoNullFunctionName() 
const { + return addDenseNoNullFunctionName_; + } + const std::string& extractFunctionName() const { + return extractFunctionName_; + } private: std::vector slots_; bool partialOutput_{false}; + std::string functionName_; + std::string initFunctionName_; + std::string addDenseFunctionName_; + std::string addDenseNoNullFunctionName_; + std::string extractFunctionName_; CompiledModuleSP module_; HashAggrJitInitFunc init_{nullptr}; HashAggrJitAddDenseFunc addDense_{nullptr}; diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 57cdb6762..44050b206 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -1179,6 +1179,7 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 - 价值更多在**正确性与可维护性**:把「子字段非 FLAT」从 IR 兜底分支收敛为 plan 阶段的显式准入回退,IR 不再生成永远走同一侧的运行时分叉。 + ### 14.4 正确性 - partial avg / average 相关单测通过:`hashAggrJitPartialAvgExtractAccumulators` @@ -1187,3 +1188,126 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 `hashAggrJitSplitsContiguousSegments`。 - 无新增回归(`hashAggrJitMergeAndExtract` / `hashAggrJitAllNullGroup` 仍 FAIL, 系既有 P0 bug,见 todolist,与本次无关)。 + +### 14.5 当前性能 +``` +============================================================================ +[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s +============================================================================ +width4_sum_nojit 2.62ms 382.13 +width4_sum_jit 2.37ms 422.76 +---------------------------------------------------------------------------- +width4_avg_nojit 3.14ms 318.68 +width4_avg_jit 2.61ms 383.32 +---------------------------------------------------------------------------- +width4_min_nojit 2.42ms 413.02 +width4_min_jit 2.41ms 415.43 +---------------------------------------------------------------------------- +width4_count_nojit 2.42ms 413.14 +width4_count_jit 1.94ms 515.97 +---------------------------------------------------------------------------- +width4_merge_sum_nojit 3.59ms 278.76 
+width4_merge_sum_jit 3.36ms 297.54 +---------------------------------------------------------------------------- +width4_merge_avg_nojit 4.73ms 211.27 +width4_merge_avg_jit 3.68ms 271.39 +---------------------------------------------------------------------------- +width4_merge_min_nojit 3.49ms 286.45 +width4_merge_min_jit 3.42ms 292.31 +---------------------------------------------------------------------------- +width4_merge_count_nojit 3.41ms 293.08 +width4_merge_count_jit 2.89ms 345.61 +---------------------------------------------------------------------------- +width8_sum_nojit 4.77ms 209.45 +width8_sum_jit 3.39ms 294.95 +---------------------------------------------------------------------------- +width8_avg_nojit 5.33ms 187.66 +width8_avg_jit 4.08ms 244.82 +---------------------------------------------------------------------------- +width8_min_nojit 3.71ms 269.36 +width8_min_jit 3.56ms 280.86 +---------------------------------------------------------------------------- +width8_count_nojit 4.27ms 233.96 +width8_count_jit 2.33ms 429.47 +---------------------------------------------------------------------------- +width8_merge_sum_nojit 6.26ms 159.83 +width8_merge_sum_jit 4.77ms 209.84 +---------------------------------------------------------------------------- +width8_merge_avg_nojit 7.37ms 135.72 +width8_merge_avg_jit 5.57ms 179.63 +---------------------------------------------------------------------------- +width8_merge_min_nojit 5.27ms 189.67 +width8_merge_min_jit 4.80ms 208.55 +---------------------------------------------------------------------------- +width8_merge_count_nojit 5.81ms 172.20 +width8_merge_count_jit 3.49ms 286.31 +---------------------------------------------------------------------------- +width16_sum_nojit 8.69ms 115.13 +width16_sum_jit 6.10ms 163.95 +---------------------------------------------------------------------------- +width16_avg_nojit 10.47ms 95.49 +width16_avg_jit 7.26ms 137.82 
+---------------------------------------------------------------------------- +width16_min_nojit 7.67ms 130.38 +width16_min_jit 6.29ms 158.91 +---------------------------------------------------------------------------- +width16_count_nojit 7.73ms 129.35 +width16_count_jit 3.54ms 282.54 +---------------------------------------------------------------------------- +width16_merge_sum_nojit 11.25ms 88.91 +width16_merge_sum_jit 7.42ms 134.78 +---------------------------------------------------------------------------- +width16_merge_avg_nojit 14.00ms 71.45 +width16_merge_avg_jit 9.94ms 100.61 +---------------------------------------------------------------------------- +width16_merge_min_nojit 10.03ms 99.69 +width16_merge_min_jit 8.16ms 122.54 +---------------------------------------------------------------------------- +width16_merge_count_nojit 9.83ms 101.69 +width16_merge_count_jit 4.95ms 202.04 +---------------------------------------------------------------------------- +width32_sum_nojit 17.00ms 58.83 +width32_sum_jit 12.28ms 81.46 +---------------------------------------------------------------------------- +width32_avg_nojit 19.46ms 51.38 +width32_avg_jit 15.26ms 65.52 +---------------------------------------------------------------------------- +width32_min_nojit 15.13ms 66.10 +width32_min_jit 12.11ms 82.55 +---------------------------------------------------------------------------- +width32_count_nojit 15.48ms 64.60 +width32_count_jit 6.94ms 144.09 +---------------------------------------------------------------------------- +width32_merge_sum_nojit 22.55ms 44.34 +width32_merge_sum_jit 16.11ms 62.08 +---------------------------------------------------------------------------- +width32_merge_avg_nojit 27.31ms 36.62 +width32_merge_avg_jit 21.04ms 47.52 +---------------------------------------------------------------------------- +width32_merge_min_nojit 20.48ms 48.82 +width32_merge_min_jit 16.56ms 60.38 
+---------------------------------------------------------------------------- +width32_merge_count_nojit 19.85ms 50.39 +width32_merge_count_jit 10.45ms 95.69 +---------------------------------------------------------------------------- +width8_decimal_sum_nojit 11.78ms 84.91 +width8_decimal_sum_jit 8.10ms 123.47 +---------------------------------------------------------------------------- +width8_decimal_avg_nojit 16.08ms 62.21 +width8_decimal_avg_jit 12.94ms 77.25 +---------------------------------------------------------------------------- +width8_double_min_nojit 5.03ms 198.86 +width8_double_min_jit 4.16ms 240.33 +---------------------------------------------------------------------------- +width8_double_max_nojit 4.17ms 239.71 +width8_double_max_jit 3.94ms 253.99 +---------------------------------------------------------------------------- +width8_high_card_partial_avg_extract_nojit 57.83ms 17.29 +width8_high_card_partial_avg_extract_jit 61.68ms 16.21 +---------------------------------------------------------------------------- +width8_high_card_partial_sum_extract_nojit 25.24ms 39.62 +width8_high_card_partial_sum_extract_jit 22.27ms 44.90 +---------------------------------------------------------------------------- +``` + + From e6935837c3d14613d27b1f1223b50afdba18b0f1 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 14:33:42 +0800 Subject: [PATCH 39/98] remove vlog in hot path --- bolt/exec/GroupingSet.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index a5a28364d..92ec2a533 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1164,8 +1164,6 @@ void GroupingSet::runHashAggrJitChunks( hashAggrJitNewGroups_[i] = groups[newGroups[i]]; } chunk.init(hashAggrJitNewGroups_.data(), newGroups.size()); - VLOG(1) << "HashAggrJit initialized new groups for chunk " - << chunk.functionName() << " newGroups=" << newGroups.size(); } chunk.addDense( @@ -1173,15 
+1171,8 @@ void GroupingSet::runHashAggrJitChunks( activeRows_.end(), hashAggrJitDecodedPtrs_.data(), inputsMayHaveNulls); - VLOG(1) << "HashAggrJit add executed: chunk=" << chunk.functionName() - << " rows=" << activeRows_.end() - << " inputsMayHaveNulls=" << inputsMayHaveNulls - << " slots=" << hashAggrJitChunkDebugString(chunk, aggregates_); for (const auto& slot : chunk.slots()) { jitExecuted[slot.aggregateIndex] = 1; - VLOG(1) << "HashAggrJit slot executed in add path: " - << hashAggrJitSlotDebugString( - slot, &aggregates_[slot.aggregateIndex]); } } } @@ -1259,14 +1250,8 @@ void GroupingSet::runHashAggrJitExtractChunks( continue; } chunk.extract(groups.data(), groups.size(), hashAggrJitResultPtrs_.data()); - VLOG(1) << "HashAggrJit extract executed: chunk=" << chunk.functionName() - << " groups=" << groups.size() - << " slots=" << hashAggrJitChunkDebugString(chunk, aggregates_); for (const auto& slot : chunk.slots()) { jitExtracted[slot.aggregateIndex] = 1; - VLOG(1) << "HashAggrJit slot executed in extract path: " - << hashAggrJitSlotDebugString( - slot, &aggregates_[slot.aggregateIndex]); } } } From 28535ed874506f66a1f876b611808fcb22f53305 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 16:27:05 +0800 Subject: [PATCH 40/98] add more benchmark cases --- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 48 ++++++++++--- doc/hashaggr-jit-benchmark.md | 68 +++++++++++++++++-- 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 2e9a98cf3..02015d066 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -55,26 +55,24 @@ class HashAggrJitBenchmark : public VectorTestBase { counts.push_back(fmt::format("count(c{})", i + 1)); } - addCase(name + "_sum", rows, sums); - addCase(name + "_avg", rows, avgs); - addCase(name + "_min", rows, mins); - addCase(name + "_count", rows, counts); 
- addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); - addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); - addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); - addCase(name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); + addCase(name + "_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_count", rows, counts, AggregationPlanKind::PartialFinal); } void addDecimalBenchmark(const std::string& name, int32_t width) { auto rows = makeDecimalRows(width); std::vector sums; std::vector avgs; + sums.reserve(width); + avgs.reserve(width); for (auto i = 0; i < width; ++i) { sums.push_back(fmt::format("spark_sum(c{})", i + 1)); avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); } - addCase(name + "_decimal_sum", rows, sums); - addCase(name + "_decimal_avg", rows, avgs); + addCase(name + "_decimal_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_decimal_avg", rows, avgs, AggregationPlanKind::PartialFinal); } void addFloatingPointMinMaxBenchmark(const std::string& name, int32_t width) { @@ -103,6 +101,29 @@ class HashAggrJitBenchmark : public VectorTestBase { addCase(name + "_partial_sum_extract", rows, sums, AggregationPlanKind::Partial); } + void addHighCardinalityMergeBenchmark(const std::string& name, int32_t width) { + auto rows = makeHighCardinalityRows(width); + std::vector sums; + std::vector avgs; + std::vector mins; + std::vector counts; + sums.reserve(width); + avgs.reserve(width); + mins.reserve(width); + counts.reserve(width); + for (auto i = 0; i < width; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); + mins.push_back(fmt::format("min(c{})", i + 1)); + counts.push_back(fmt::format("count(c{})", i + 1)); + 
} + + addCase(name + "_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_count", rows, counts, AggregationPlanKind::PartialFinal); + } + private: std::vector makeRows(int32_t width) { std::vector names; @@ -273,7 +294,14 @@ int main(int argc, char** argv) { benchmark.addBenchmark("width8", 8); benchmark.addBenchmark("width16", 16); benchmark.addBenchmark("width32", 32); + benchmark.addHighCardinalityMergeBenchmark("width4_high_card", 4); + benchmark.addHighCardinalityMergeBenchmark("width8_high_card", 8); + benchmark.addHighCardinalityMergeBenchmark("width16_high_card", 16); + benchmark.addHighCardinalityMergeBenchmark("width32_high_card", 32); + benchmark.addDecimalBenchmark("width4", 4); benchmark.addDecimalBenchmark("width8", 8); + benchmark.addDecimalBenchmark("width16", 16); + benchmark.addDecimalBenchmark("width32", 32); benchmark.addFloatingPointMinMaxBenchmark("width8", 8); benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 44050b206..8f0fb5ff2 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -8,7 +8,8 @@ benchmark 不依赖它)。 - **benchmark**:`bolt/exec/benchmarks/HashAggrJitBenchmark.cpp`,目标 `bolt_hashaggr_jit_benchmark`。覆盖 sum/avg/min/count(width 4/8/16/32)、 - merge(partial+final)、decimal sum/avg、double min/max、partial extract。 + merge(partial+final)、decimal sum/avg(当前按 `PartialFinal` 路径评测)、 + double min/max、partial extract。 - **数据规模**:每用例 20 batch × 10000 行。 - **关键控制**: - JIT 模块为进程级 LRU 全局缓存,预热后**每个 JIT 函数仅编译一次**(已用 VLOG 验证 @@ -72,8 +73,9 @@ double_max 0.63x · partial_avg_extract 0.74x · partial_sum_extract 0.84x 几乎相同(如 width8_sum jit ≈ 5.0ms 两者一致),说明耗时与组数无关、只与行数相关—— 即**每行 add 成本** JIT 高于非 JIT 的向量化路径。这正是“低基数本应让 JIT 更受益”的 预期被反转的根本原因。 -4. 
**decimal_avg(0.75x) 优于 decimal_sum(0.40x)**:decimal_avg final 走非 JIT(spark - rescale 复杂逻辑),反而拖累较小,侧面印证当前 JIT 计算路径偏慢。 +4. **decimal_avg(0.75x) 曾优于 decimal_sum(0.40x)**:这组历史数据采集时,decimal_avg + final 仍走非 JIT(Spark rescale 复杂逻辑),因此拖累相对较小。当前已补齐 final + decimal avg extract JIT helper,新的 decimal_avg 结果需以 `PartialFinal` 基准重新观察。 ## 5. 结论与建议 @@ -1190,10 +1192,10 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 系既有 P0 bug,见 todolist,与本次无关)。 ### 14.5 当前性能 + + ``` ============================================================================ -[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s -============================================================================ width4_sum_nojit 2.62ms 382.13 width4_sum_jit 2.37ms 422.76 ---------------------------------------------------------------------------- @@ -1310,4 +1312,60 @@ width8_high_card_partial_sum_extract_jit 22.27ms 44.90 ---------------------------------------------------------------------------- ``` +## 15. decimal_avg final extract JIT 补齐说明 + +### 15.1 背景 + +在本轮修复前,decimal avg 只有 partial extract 走 JIT helper,final extract 仍留在 non-JIT 路径: + +- planner/codegen 侧通过 `canCompileDecimalAvgExtract(..., partialOutput)` 仅允许 partial path; +- runtime 侧 `jit_HashAggrExtractFinalDecimalAvg` 是空 stub,仅用于 link 成功; +- 因此历史上部分 `decimal_avg` benchmark 结果,实际测到的是“JIT add/merge + non-JIT final extract”的混合路径。 + +这也是第 4 章里“decimal_avg 曾优于 decimal_sum”的一个背景因素:当时 decimal avg 没有承担 final decimal +rescale/divide 的 JIT extract 成本。 + +### 15.2 本轮实现 + +本轮已补齐 final decimal avg extract 的 JIT 支持,策略是**继续保持 helper 模式**,不把 Spark decimal avg 的 +divide / overflow / precision-rescale 逻辑直接展开成 LLVM IR。 + +具体改动: + +1. **放开 codegen**:`decimal avg` 的 extract 现在 partial / final 都允许编译; +2. **扩展 helper ABI**:avg extract helper 额外接收最终结果 decimal 的 `resultPrecision/resultScale`; +3. 
**实现 final runtime helper**:在 runtime 中镜像 non-JIT `computeAvg` 语义: + - `adjustSumForOverflow` + - `divideWithRoundUp` + - `rescaleWithRoundUp` + - short / long decimal 分类型写回 `FlatVector` +4. **benchmark 口径更新**:`HashAggrJitBenchmark` 中 `decimal_sum/decimal_avg` 统一按 `PartialFinal` 路径评测, + 避免继续把 decimal avg 记成“只测 partial + 非 JIT final extract”的旧口径。 + +### 15.3 功能验证 + +本轮未新增一组完整 benchmark 数据表,但已完成功能与构建验证: + +- 构建通过:`bolt_thrustjit`、`bolt_exec`、`bolt_functions_spark_aggregates_test` +- Average 相关测试通过: + - `AverageAggregationTest.avgAllNulls` + - `AverageAggregationTest.avgDecimal` + - `AverageAggregationTest.avgDecimalWithMultipleRowVectors` + - `AverageAggregationTest.rowBasedSpillDecimalAvg` + +说明 final decimal avg extract JIT 至少已经满足当前 Spark avg 语义下的基础正确性要求: + +- `count == 0` 输出 null; +- sum overflow 无法修正时输出 null; +- divide / rescale overflow 时输出 null; +- short decimal 与 long decimal 结果类型都可写回。 + +### 15.4 对阅读本报告的影响 + +1. **第 2/3/4 章中的早期 decimal_avg 结论需要加注理解**:这些历史结论产生时,final decimal avg extract 还未走 JIT。 +2. **后续若继续比较 decimal_avg 的 JIT/no-JIT 收益,应以当前 `PartialFinal` benchmark 口径为准**。 +3. 
**当前 decimal_avg benchmark 的收益解释更完整**:它现在同时覆盖 JIT add、JIT merge 和 JIT final extract, + 比之前更接近真实生产路径。 +换句话说:从这一节之后,文档里关于 decimal avg 的性能讨论应默认理解为“**final extract JIT 已补齐**”的版本; +如果引用更早的数据,需要显式说明那是旧口径历史快照。 From 93dddddbf9a3a44c6dd72dfe017741def227e981 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 16:28:03 +0800 Subject: [PATCH 41/98] support decimal_avg final extract --- .../sparksql/aggregates/AverageAggregate.cpp | 5 +- bolt/jit/aggregation/HashAggrJit.cpp | 94 +++++++++++-- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 8 +- .../runtime/HashAggrDecimalRuntime.cpp | 130 ++++++++++++++++-- 4 files changed, 208 insertions(+), 29 deletions(-) diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 92ada29c6..70e46a962 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -177,8 +177,9 @@ class DecimalAverageAggregate : public DecimalAggregate { context.isRawInput ? context.inputType : context.inputType->childAt(0); const auto [sumPrecision, sumScale] = getDecimalPrecisionScale(*sumType_.get()); - const auto [resultPrecision, resultScale] = - getDecimalPrecisionScale(*this->resultType().get()); + const auto [resultPrecision, resultScale] = context.isPartialOutput + ? 
std::pair{0, 0} + : getDecimalPrecisionScale(*this->resultType().get()); return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::Avg, .inputKind = valueType->isShortDecimal() diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 0c93f9a9a..9a89731ef 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -155,8 +155,8 @@ void ensureBuiltinDeclarations(llvm::Module& module) { declareFunction(module, "jit_HashAggrSetFlatI64", voidTy, {i8PtrTy, i32Ty, i64Ty, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatFloat", voidTy, {i8PtrTy, i32Ty, floatTy, i8Ty}); declareFunction(module, "jit_HashAggrSetFlatDouble", voidTy, {i8PtrTy, i32Ty, doubleTy, i8Ty}); - // Decimal extract helpers: (vector, row, group, offset, precision, scale, - // longDecimal). + // Decimal extract helpers. + // Sum: (vector, row, group, offset, precision, scale, longDecimal). declareFunction( module, "jit_HashAggrExtractFinalDecimalSum", @@ -171,12 +171,12 @@ void ensureBuiltinDeclarations(llvm::Module& module) { module, "jit_HashAggrExtractFinalDecimalAvg", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty, i8Ty}); declareFunction( module, "jit_HashAggrExtractPartialDecimalAvg", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty, i8Ty}); } llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { @@ -612,13 +612,50 @@ llvm::Value* HashAggrJitCodegen::loadDecodedRowField( if (field == 0 || field == 1) { auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( builder(), decoded, field, false); + auto* hasRawValues = builder().CreateICmpNE( + rawValues, + llvm::ConstantPointerNull::get( + llvm::PointerType::get(builder().getContext(), 0))); + auto* fastBlock = llvm::BasicBlock::Create( + module_.getContext(), + 
"row_field_raw_load", + builder().GetInsertBlock()->getParent()); + auto* slowBlock = llvm::BasicBlock::Create( + module_.getContext(), + "row_field_helper_load", + builder().GetInsertBlock()->getParent()); + auto* doneBlock = llvm::BasicBlock::Create( + module_.getContext(), + "row_field_load_done", + builder().GetInsertBlock()->getParent()); + builder().CreateCondBr(hasRawValues, fastBlock, slowBlock); + + builder().SetInsertPoint(fastBlock); auto* index = ::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); auto* type = llvmType(kind); auto* typedValues = builder().CreatePointerCast(rawValues, type->getPointerTo()); auto* valueAddr = builder().CreateInBoundsGEP( type, typedValues, builder().CreateZExt(index, builder().getInt64Ty())); - auto* value = builder().CreateLoad(type, valueAddr); - value->setAlignment(llvm::Align(1)); + auto* fastValue = builder().CreateLoad(type, valueAddr); + fastValue->setAlignment(llvm::Align(1)); + builder().CreateBr(doneBlock); + auto* fastEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(slowBlock); + const auto name = decodedRowFieldFunction(kind); + BOLT_CHECK( + !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); + auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); + auto* slowValue = builder().CreateCall( + module_.getFunction(name), + {decodedVector, row, builder().getInt32(field)}); + builder().CreateBr(doneBlock); + auto* slowEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(doneBlock); + auto* value = builder().CreatePHI(llvmType(kind), 2, "row_field_value"); + value->addIncoming(fastValue, fastEnd); + value->addIncoming(slowValue, slowEnd); return value; } const auto name = decodedRowFieldFunction(kind); @@ -635,14 +672,31 @@ llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( llvm::Value* row, int32_t field) const { if (field == 0 || field == 1) { + auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( + 
builder(), decoded, field, false); auto* rawNulls = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( builder(), decoded, field, true); + auto* hasRawValues = builder().CreateICmpNE( + rawValues, + llvm::ConstantPointerNull::get( + llvm::PointerType::get(builder().getContext(), 0))); + auto* rawPathBlock = llvm::BasicBlock::Create( + module_.getContext(), "row_field_raw_null_path", builder().GetInsertBlock()->getParent()); + auto* helperPathBlock = llvm::BasicBlock::Create( + module_.getContext(), + "row_field_helper_null_path", + builder().GetInsertBlock()->getParent()); + auto* doneBlock = llvm::BasicBlock::Create( + module_.getContext(), "row_field_null_done", builder().GetInsertBlock()->getParent()); + builder().CreateCondBr(hasRawValues, rawPathBlock, helperPathBlock); + + builder().SetInsertPoint(rawPathBlock); auto* hasRawNulls = builder().CreateICmpNE( rawNulls, llvm::ConstantPointerNull::get(builder().getInt64Ty()->getPointerTo())); auto* nullCheckBlock = llvm::BasicBlock::Create( module_.getContext(), "row_field_null_check", builder().GetInsertBlock()->getParent()); auto* rawDoneBlock = llvm::BasicBlock::Create( - module_.getContext(), "row_field_null_done", builder().GetInsertBlock()->getParent()); + module_.getContext(), "row_field_raw_null_done", builder().GetInsertBlock()->getParent()); builder().CreateCondBr(hasRawNulls, nullCheckBlock, rawDoneBlock); auto* noNullsEnd = builder().GetInsertBlock(); @@ -656,7 +710,25 @@ llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( auto* fastNull = builder().CreatePHI(builder().getInt1Ty(), 2, "row_field_raw_is_null"); fastNull->addIncoming(builder().getFalse(), noNullsEnd); fastNull->addIncoming(isNull, nullCheckEnd); - return fastNull; + builder().CreateBr(doneBlock); + auto* rawEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(helperPathBlock); + auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); + auto* helperNull = builder().CreateICmpNE( + 
builder().CreateCall( + module_.getFunction("jit_GetDecodedRowFieldIsNull"), + {decodedVector, row, builder().getInt32(field)}), + builder().getInt8(0)); + builder().CreateBr(doneBlock); + auto* helperEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(doneBlock); + auto* result = + builder().CreatePHI(builder().getInt1Ty(), 2, "row_field_is_null"); + result->addIncoming(fastNull, rawEnd); + result->addIncoming(helperNull, helperEnd); + return result; } auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); return builder().CreateICmpNE( @@ -764,7 +836,9 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalAvg" : "jit_HashAggrExtractFinalDecimalAvg"; auto* longDecimal = builder().getInt8( - slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + slot.desc.auxPrecision > bytedance::bolt::ShortDecimalType::kMaxPrecision + ? 1 + : 0); auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), @@ -774,6 +848,8 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( builder().getInt32(slot.offset), builder().getInt32(slot.desc.precision), builder().getInt32(slot.desc.scale), + builder().getInt32(slot.desc.auxPrecision), + builder().getInt32(slot.desc.auxScale), longDecimal}); } diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index fbcec7bdf..f78d32ea8 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -120,10 +120,10 @@ void compileDecimalAvgAddIntermediateResults( } bool canCompileDecimalAvgExtract(const HashAggrJitSlot&, bool partialOutput) { - // Only the partial (extractAccumulators) path is JIT-supported for decimal - // avg. Final avg needs the full per-aggregate rescale logic and stays on - // the non-JIT path. 
- return partialOutput; + // Both partial (extractAccumulators) and final extract go through runtime + // helpers. Final decimal avg keeps the divide/rescale logic in the helper to + // avoid duplicating Spark decimal semantics in LLVM IR. + return true; } void compileDecimalAvgExtract( diff --git a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp index 9991ab16a..f18fff1d9 100644 --- a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp +++ b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp @@ -9,11 +9,14 @@ // adjustment and write the result into the output vector. Resolved by the ORC // JIT through the process global symbol table, so they only need default // visibility and to be linked into the host process. These mirror the -// per-aggregate computeFinalValue logic but depend only on the type-layer -// DecimalUtil, so they live next to the other HashAggr runtime helpers. +// per-aggregate computeFinalValue logic but depend only on shared decimal +// utility helpers, so they live next to the other HashAggr runtime helpers. 
#ifdef ENABLE_BOLT_JIT +#include + +#include "bolt/functions/sparksql/DecimalUtil.h" #include "bolt/jit/aggregation/HashAggrJitTypes.h" #include "bolt/type/DecimalUtil.h" #include "bolt/vector/ComplexVector.h" @@ -41,6 +44,71 @@ bytedance::bolt::int128_t jitDecimalSumComputeFinal( return sum; } +uint8_t jitDecimalAvgComputeRescaleFactor( + uint8_t fromScale, + uint8_t toScale, + uint8_t resultScale) { + return resultScale - fromScale + toScale; +} + +std::pair jitDecimalAvgComputeResultPrecisionScale( + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale) { + uint8_t intDigits = aPrecision - aScale + bScale; + uint8_t scale = std::max(6, aScale + bPrecision + 1); + uint8_t precision = intDigits + scale; + return bytedance::bolt::functions::sparksql::DecimalUtil::adjustPrecisionScale( + precision, scale); +} + +template +std::optional jitDecimalAvgComputeFinal( + const bytedance::bolt::jit::JitDecimalAvgState* state, + int32_t sumPrecision, + int32_t sumScale, + int32_t resultPrecision, + int32_t resultScale) { + auto adjustedSum = bytedance::bolt::DecimalUtil::adjustSumForOverflow( + state->sum, state->overflow); + if (!adjustedSum.has_value()) { + return std::nullopt; + } + + constexpr uint8_t kCountPrecision = 20; + constexpr uint8_t kCountScale = 0; + const auto [avgPrecision, avgScale] = jitDecimalAvgComputeResultPrecisionScale( + static_cast(sumPrecision), + static_cast(sumScale), + kCountPrecision, + kCountScale); + const auto sumRescale = jitDecimalAvgComputeRescaleFactor( + static_cast(sumScale), kCountScale, avgScale); + + bytedance::bolt::int128_t avg = 0; + bool overflow = false; + bytedance::bolt::functions::sparksql::DecimalUtil:: + divideWithRoundUp( + avg, adjustedSum.value(), state->count, sumRescale, overflow); + if (overflow) { + return std::nullopt; + } + + TResult rescaledValue; + const auto status = bytedance::bolt::DecimalUtil:: + rescaleWithRoundUp( + avg, + avgPrecision, + avgScale, + static_cast(resultPrecision), + 
static_cast(resultScale), + rescaledValue); + return status.ok() ? std::optional(rescaledValue) : std::nullopt; +} + } // namespace extern "C" { @@ -109,8 +177,6 @@ jit_HashAggrExtractPartialDecimalSum( // Partial decimal avg extract: write row(sum:decimal, count:bigint). // Overflow during sum adjustment -> sum child set to null, count kept. -// (Final decimal avg extract stays on the non-JIT path; the rescale logic is -// too coupled to per-aggregate precision metadata.) __attribute__((__visibility__("default"))) void jit_HashAggrExtractPartialDecimalAvg( char* vector, @@ -119,6 +185,8 @@ jit_HashAggrExtractPartialDecimalAvg( int32_t offset, int32_t /*precision*/, int32_t /*scale*/, + int32_t /*resultPrecision*/, + int32_t /*resultScale*/, int8_t /*longDecimal*/) { auto* state = reinterpret_cast(group + offset); @@ -139,18 +207,52 @@ jit_HashAggrExtractPartialDecimalAvg( } } -// Final decimal avg extract is intentionally not implemented in JIT; the -// declaration exists so the JIT module link succeeds, but it is never called -// because canExtract returns false for the final (non-partial) output. +// Final decimal avg extract: write FlatVector. Null when +// the group is empty (all inputs null) or any overflow/rescale step fails. 
__attribute__((__visibility__("default"))) void jit_HashAggrExtractFinalDecimalAvg( - char* /*vector*/, - int32_t /*row*/, - char* /*group*/, - int32_t /*offset*/, - int32_t /*precision*/, - int32_t /*scale*/, - int8_t /*longDecimal*/) {} + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t resultPrecision, + int32_t resultScale, + int8_t longDecimal) { + auto* state = + reinterpret_cast(group + offset); + if (longDecimal) { + auto* flat = reinterpret_cast(vector) + ->as>(); + if (state->count == 0) { + flat->setNull(row, true); + return; + } + auto result = jitDecimalAvgComputeFinal( + state, precision, scale, resultPrecision, resultScale); + if (result.has_value()) { + flat->set(row, result.value()); + } else { + flat->setNull(row, true); + } + return; + } + + auto* flat = reinterpret_cast(vector) + ->as>(); + if (state->count == 0) { + flat->setNull(row, true); + return; + } + auto result = + jitDecimalAvgComputeFinal(state, precision, scale, resultPrecision, resultScale); + if (result.has_value()) { + flat->set(row, result.value()); + } else { + flat->setNull(row, true); + } +} } // extern "C" From ffd029c60e0cf2a872e2eeed54eb26baae42476e Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 17:14:45 +0800 Subject: [PATCH 42/98] fix benchmark crash by skip decimal avg final extract --- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 02015d066..244a1e031 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -55,24 +55,27 @@ class HashAggrJitBenchmark : public VectorTestBase { counts.push_back(fmt::format("count(c{})", i + 1)); } - addCase(name + "_sum", rows, sums, AggregationPlanKind::PartialFinal); - addCase(name + "_avg", rows, avgs, 
AggregationPlanKind::PartialFinal); - addCase(name + "_min", rows, mins, AggregationPlanKind::PartialFinal); - addCase(name + "_count", rows, counts, AggregationPlanKind::PartialFinal); + addCase(name + "_sum", rows, sums); + addCase(name + "_avg", rows, avgs); + addCase(name + "_min", rows, mins); + addCase(name + "_count", rows, counts); + addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); } void addDecimalBenchmark(const std::string& name, int32_t width) { auto rows = makeDecimalRows(width); std::vector sums; std::vector avgs; - sums.reserve(width); - avgs.reserve(width); for (auto i = 0; i < width; ++i) { sums.push_back(fmt::format("spark_sum(c{})", i + 1)); avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); } - addCase(name + "_decimal_sum", rows, sums, AggregationPlanKind::PartialFinal); - addCase(name + "_decimal_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_decimal_sum", rows, sums); + addCase(name + "_decimal_avg", rows, avgs); } void addFloatingPointMinMaxBenchmark(const std::string& name, int32_t width) { @@ -298,10 +301,7 @@ int main(int argc, char** argv) { benchmark.addHighCardinalityMergeBenchmark("width8_high_card", 8); benchmark.addHighCardinalityMergeBenchmark("width16_high_card", 16); benchmark.addHighCardinalityMergeBenchmark("width32_high_card", 32); - benchmark.addDecimalBenchmark("width4", 4); benchmark.addDecimalBenchmark("width8", 8); - benchmark.addDecimalBenchmark("width16", 16); - benchmark.addDecimalBenchmark("width32", 32); benchmark.addFloatingPointMinMaxBenchmark("width8", 8); benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); From 284f2b63dfa8bc351e964a1de9fe5405d46df06e Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: 
Thu, 11 Jun 2026 17:28:51 +0800 Subject: [PATCH 43/98] add more cases in benchmark --- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 41 ++++++++++++++++--- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 244a1e031..cdaccb095 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -88,6 +88,16 @@ class HashAggrJitBenchmark : public VectorTestBase { } addCase(name + "_double_min", rows, mins); addCase(name + "_double_max", rows, maxs); + addCase( + name + "_merge_double_min", + rows, + mins, + AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_double_max", + rows, + maxs, + AggregationPlanKind::PartialFinal); } void addHighCardinalityExtractBenchmark(const std::string& name, int32_t width) { @@ -100,8 +110,8 @@ class HashAggrJitBenchmark : public VectorTestBase { avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); sums.push_back(fmt::format("spark_sum(c{})", i + 1)); } - addCase(name + "_partial_avg_extract", rows, avgs, AggregationPlanKind::Partial); - addCase(name + "_partial_sum_extract", rows, sums, AggregationPlanKind::Partial); + addCase(name + "_partial_avg", rows, avgs, AggregationPlanKind::Partial); + addCase(name + "_partial_sum", rows, sums, AggregationPlanKind::Partial); } void addHighCardinalityMergeBenchmark(const std::string& name, int32_t width) { @@ -121,10 +131,16 @@ class HashAggrJitBenchmark : public VectorTestBase { counts.push_back(fmt::format("count(c{})", i + 1)); } - addCase(name + "_sum", rows, sums, AggregationPlanKind::PartialFinal); - addCase(name + "_avg", rows, avgs, AggregationPlanKind::PartialFinal); - addCase(name + "_min", rows, mins, AggregationPlanKind::PartialFinal); - addCase(name + "_count", rows, counts, AggregationPlanKind::PartialFinal); + addCase(name + "_sum", rows, sums); + addCase(name + "_avg", rows, avgs); + addCase(name + "_min", 
rows, mins); + addCase(name + "_count", rows, counts); + + addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); } private: @@ -297,13 +313,26 @@ int main(int argc, char** argv) { benchmark.addBenchmark("width8", 8); benchmark.addBenchmark("width16", 16); benchmark.addBenchmark("width32", 32); + benchmark.addHighCardinalityMergeBenchmark("width4_high_card", 4); benchmark.addHighCardinalityMergeBenchmark("width8_high_card", 8); benchmark.addHighCardinalityMergeBenchmark("width16_high_card", 16); benchmark.addHighCardinalityMergeBenchmark("width32_high_card", 32); + + benchmark.addDecimalBenchmark("width4", 4); benchmark.addDecimalBenchmark("width8", 8); + benchmark.addDecimalBenchmark("width16", 16); + benchmark.addDecimalBenchmark("width32", 32); + + benchmark.addFloatingPointMinMaxBenchmark("width4", 4); benchmark.addFloatingPointMinMaxBenchmark("width8", 8); + benchmark.addFloatingPointMinMaxBenchmark("width16", 16); + benchmark.addFloatingPointMinMaxBenchmark("width32", 32); + + benchmark.addHighCardinalityExtractBenchmark("width4_high_card", 4); benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); + benchmark.addHighCardinalityExtractBenchmark("width16_high_card", 16); + benchmark.addHighCardinalityExtractBenchmark("width32_high_card", 32); folly::runBenchmarks(); return 0; From b2fbef35999e84d5f32cd8e86cbb0b8d21a721f4 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 17:39:15 +0800 Subject: [PATCH 44/98] fix final extract crash in decimal avg --- bolt/exec/GroupingSet.cpp | 12 +- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 4 +- doc/hashaggr-jit-benchmark.md | 232 ++++++++++++------ 3 files changed, 166 insertions(+), 82 deletions(-) diff --git 
a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 92ec2a533..e30c5b5bd 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -151,7 +151,8 @@ void fillHashAggrJitRowFieldInputs( // caller must fall back to the non-JIT extract path. bool fillHashAggrJitPartialAvgOutput( jit::HashAggrJitOutput& output, - BaseVector* vector) { + BaseVector* vector, + const jit::HashAggrJitSlot& slot) { auto* rowVector = vector->asUnchecked(); if (rowVector->childrenSize() < 2) { return false; @@ -162,8 +163,11 @@ bool fillHashAggrJitPartialAvgOutput( countVector->encoding() != VectorEncoding::Simple::FLAT) { return false; } - output.rowField0Values = - sumVector->asUnchecked>()->mutableRawValues(); + output.rowField0Values = slot.desc.decimal + ? static_cast( + sumVector->asUnchecked>()->mutableRawValues()) + : static_cast( + sumVector->asUnchecked>()->mutableRawValues()); output.rowField0Nulls = sumVector->mutableRawNulls(); output.rowField1Values = countVector->asUnchecked>()->mutableRawValues(); @@ -1234,7 +1238,7 @@ void GroupingSet::runHashAggrJitExtractChunks( slot.desc.kind == jit::HashAggrJitKind::Avg) { hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); if (!fillHashAggrJitPartialAvgOutput( - hashAggrJitOutputs_[slotIndex], aggregateVector.get())) { + hashAggrJitOutputs_[slotIndex], aggregateVector.get(), slot)) { canRunChunk = false; skipReason = "partial avg row fields are not flat"; break; diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index cdaccb095..fc9f2b511 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -74,8 +74,8 @@ class HashAggrJitBenchmark : public VectorTestBase { sums.push_back(fmt::format("spark_sum(c{})", i + 1)); avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); } - addCase(name + "_decimal_sum", rows, sums); - addCase(name + "_decimal_avg", rows, avgs); + addCase(name + 
"_decimal_sum", rows, sums, AggregationPlanKind::PartialFinal); + addCase(name + "_decimal_avg", rows, avgs, AggregationPlanKind::PartialFinal); } void addFloatingPointMinMaxBenchmark(const std::string& name, int32_t width) { diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 8f0fb5ff2..6167c8ec3 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -1196,119 +1196,199 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 ``` ============================================================================ -width4_sum_nojit 2.62ms 382.13 -width4_sum_jit 2.37ms 422.76 +[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s +============================================================================ ---------------------------------------------------------------------------- -width4_avg_nojit 3.14ms 318.68 -width4_avg_jit 2.61ms 383.32 ---------------------------------------------------------------------------- -width4_min_nojit 2.42ms 413.02 -width4_min_jit 2.41ms 415.43 ---------------------------------------------------------------------------- -width4_count_nojit 2.42ms 413.14 -width4_count_jit 1.94ms 515.97 ---------------------------------------------------------------------------- -width4_merge_sum_nojit 3.59ms 278.76 -width4_merge_sum_jit 3.36ms 297.54 ---------------------------------------------------------------------------- -width4_merge_avg_nojit 4.73ms 211.27 -width4_merge_avg_jit 3.68ms 271.39 ---------------------------------------------------------------------------- -width4_merge_min_nojit 3.49ms 286.45 -width4_merge_min_jit 3.42ms 292.31 ---------------------------------------------------------------------------- -width4_merge_count_nojit 3.41ms 293.08 -width4_merge_count_jit 2.89ms 345.61 ---------------------------------------------------------------------------- -width8_sum_nojit 4.77ms 209.45 -width8_sum_jit 3.39ms 294.95 +width8_sum_nojit 4.75ms 210.73 
+width8_sum_jit 3.49ms 286.89 ---------------------------------------------------------------------------- -width8_avg_nojit 5.33ms 187.66 -width8_avg_jit 4.08ms 244.82 +width8_avg_nojit 5.36ms 186.61 +width8_avg_jit 4.17ms 239.53 ---------------------------------------------------------------------------- -width8_min_nojit 3.71ms 269.36 -width8_min_jit 3.56ms 280.86 +width8_min_nojit 3.75ms 266.36 +width8_min_jit 3.47ms 288.20 ---------------------------------------------------------------------------- -width8_count_nojit 4.27ms 233.96 -width8_count_jit 2.33ms 429.47 +width8_count_nojit 4.39ms 227.95 +width8_count_jit 2.34ms 427.14 ---------------------------------------------------------------------------- -width8_merge_sum_nojit 6.26ms 159.83 -width8_merge_sum_jit 4.77ms 209.84 +width8_merge_sum_nojit 6.22ms 160.77 +width8_merge_sum_jit 4.75ms 210.69 ---------------------------------------------------------------------------- -width8_merge_avg_nojit 7.37ms 135.72 -width8_merge_avg_jit 5.57ms 179.63 +width8_merge_avg_nojit 7.59ms 131.72 +width8_merge_avg_jit 5.84ms 171.16 ---------------------------------------------------------------------------- -width8_merge_min_nojit 5.27ms 189.67 -width8_merge_min_jit 4.80ms 208.55 +width8_merge_min_nojit 5.64ms 177.34 +width8_merge_min_jit 4.89ms 204.57 ---------------------------------------------------------------------------- -width8_merge_count_nojit 5.81ms 172.20 -width8_merge_count_jit 3.49ms 286.31 +width8_merge_count_nojit 6.08ms 164.52 +width8_merge_count_jit 3.68ms 271.96 ---------------------------------------------------------------------------- -width16_sum_nojit 8.69ms 115.13 -width16_sum_jit 6.10ms 163.95 +width16_sum_nojit 8.94ms 111.84 +width16_sum_jit 6.03ms 165.78 ---------------------------------------------------------------------------- -width16_avg_nojit 10.47ms 95.49 -width16_avg_jit 7.26ms 137.82 +width16_avg_nojit 10.71ms 93.37 +width16_avg_jit 7.39ms 135.30 
---------------------------------------------------------------------------- -width16_min_nojit 7.67ms 130.38 -width16_min_jit 6.29ms 158.91 +width16_min_nojit 7.62ms 131.30 +width16_min_jit 6.22ms 160.80 ---------------------------------------------------------------------------- -width16_count_nojit 7.73ms 129.35 -width16_count_jit 3.54ms 282.54 +width16_count_nojit 7.87ms 127.05 +width16_count_jit 3.62ms 275.88 +---------------------------------------------------------------------------- +width16_merge_sum_nojit 11.47ms 87.15 +width16_merge_sum_jit 7.79ms 128.42 +---------------------------------------------------------------------------- +width16_merge_avg_nojit 14.40ms 69.45 +width16_merge_avg_jit 10.53ms 94.95 +---------------------------------------------------------------------------- +width16_merge_min_nojit 10.14ms 98.61 +width16_merge_min_jit 7.73ms 129.41 +---------------------------------------------------------------------------- +width16_merge_count_nojit 9.62ms 103.94 +width16_merge_count_jit 5.16ms 193.66 +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- 
+---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +width8_high_card_sum_nojit 40.01ms 25.00 +width8_high_card_sum_jit 31.61ms 31.63 +---------------------------------------------------------------------------- +width8_high_card_avg_nojit 46.37ms 21.56 +width8_high_card_avg_jit 38.18ms 26.19 +---------------------------------------------------------------------------- +width8_high_card_min_nojit 37.29ms 26.82 +width8_high_card_min_jit 30.62ms 32.66 +---------------------------------------------------------------------------- +width8_high_card_count_nojit 34.43ms 29.04 +width8_high_card_count_jit 30.06ms 33.27 +---------------------------------------------------------------------------- +width8_high_card_merge_sum_nojit 61.33ms 16.31 +width8_high_card_merge_sum_jit 52.05ms 19.21 +---------------------------------------------------------------------------- +width8_high_card_merge_avg_nojit 94.24ms 10.61 +width8_high_card_merge_avg_jit 78.64ms 12.72 +---------------------------------------------------------------------------- +width8_high_card_merge_min_nojit 62.70ms 15.95 +width8_high_card_merge_min_jit 51.41ms 19.45 +---------------------------------------------------------------------------- +width8_high_card_merge_count_nojit 57.81ms 17.30 +width8_high_card_merge_count_jit 53.43ms 18.72 +---------------------------------------------------------------------------- +width16_high_card_sum_nojit 70.22ms 14.24 +width16_high_card_sum_jit 55.12ms 18.14 +---------------------------------------------------------------------------- +width16_high_card_avg_nojit 84.58ms 11.82 +width16_high_card_avg_jit 66.30ms 15.08 
+---------------------------------------------------------------------------- +width16_high_card_min_nojit 67.12ms 14.90 +width16_high_card_min_jit 54.09ms 18.49 +---------------------------------------------------------------------------- +width16_high_card_count_nojit 59.38ms 16.84 +width16_high_card_count_jit 47.88ms 20.89 +---------------------------------------------------------------------------- +width16_high_card_merge_sum_nojit 113.95ms 8.78 +width16_high_card_merge_sum_jit 93.46ms 10.70 +---------------------------------------------------------------------------- +width16_high_card_merge_avg_nojit 157.39ms 6.35 +width16_high_card_merge_avg_jit 137.35ms 7.28 +---------------------------------------------------------------------------- +width16_high_card_merge_min_nojit 110.10ms 9.08 +width16_high_card_merge_min_jit 91.93ms 10.88 +---------------------------------------------------------------------------- +width16_high_card_merge_count_nojit 103.19ms 9.69 +width16_high_card_merge_count_jit 88.89ms 11.25 +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +width8_decimal_sum_nojit 12.15ms 82.32 +width8_decimal_sum_jit 8.16ms 122.56 
+---------------------------------------------------------------------------- +width8_decimal_avg_nojit 16.20ms 61.74 +width8_decimal_avg_jit 9.55ms 104.73 +---------------------------------------------------------------------------- +width16_decimal_sum_nojit 23.55ms 42.46 +width16_decimal_sum_jit 16.41ms 60.93 +---------------------------------------------------------------------------- +width16_decimal_avg_nojit 32.43ms 30.83 +width16_decimal_avg_jit 19.10ms 52.36 +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +---------------------------------------------------------------------------- +width8_double_min_nojit 4.97ms 201.41 +width8_double_min_jit 4.19ms 238.53 +---------------------------------------------------------------------------- +width8_double_max_nojit 4.22ms 236.76 +width8_double_max_jit 3.89ms 257.11 ---------------------------------------------------------------------------- -width16_merge_sum_nojit 11.25ms 88.91 -width16_merge_sum_jit 7.42ms 134.78 +width8_merge_double_min_nojit 6.96ms 143.68 +width8_merge_double_min_jit 5.71ms 175.20 ---------------------------------------------------------------------------- -width16_merge_avg_nojit 14.00ms 71.45 -width16_merge_avg_jit 9.94ms 100.61 +width8_merge_double_max_nojit 6.01ms 166.45 +width8_merge_double_max_jit 5.10ms 195.97 ---------------------------------------------------------------------------- -width16_merge_min_nojit 10.03ms 99.69 -width16_merge_min_jit 8.16ms 122.54 +width16_double_min_nojit 9.93ms 100.71 +width16_double_min_jit 7.81ms 128.06 
---------------------------------------------------------------------------- -width16_merge_count_nojit 9.83ms 101.69 -width16_merge_count_jit 4.95ms 202.04 +width16_double_max_nojit 8.75ms 114.27 +width16_double_max_jit 7.18ms 139.30 ---------------------------------------------------------------------------- -width32_sum_nojit 17.00ms 58.83 -width32_sum_jit 12.28ms 81.46 +width16_merge_double_min_nojit 12.39ms 80.74 +width16_merge_double_min_jit 9.53ms 104.88 ---------------------------------------------------------------------------- -width32_avg_nojit 19.46ms 51.38 -width32_avg_jit 15.26ms 65.52 +width16_merge_double_max_nojit 10.93ms 91.50 +width16_merge_double_max_jit 9.04ms 110.68 ---------------------------------------------------------------------------- -width32_min_nojit 15.13ms 66.10 -width32_min_jit 12.11ms 82.55 ---------------------------------------------------------------------------- -width32_count_nojit 15.48ms 64.60 -width32_count_jit 6.94ms 144.09 ---------------------------------------------------------------------------- -width32_merge_sum_nojit 22.55ms 44.34 -width32_merge_sum_jit 16.11ms 62.08 ---------------------------------------------------------------------------- -width32_merge_avg_nojit 27.31ms 36.62 -width32_merge_avg_jit 21.04ms 47.52 ---------------------------------------------------------------------------- -width32_merge_min_nojit 20.48ms 48.82 -width32_merge_min_jit 16.56ms 60.38 ---------------------------------------------------------------------------- -width32_merge_count_nojit 19.85ms 50.39 -width32_merge_count_jit 10.45ms 95.69 ---------------------------------------------------------------------------- -width8_decimal_sum_nojit 11.78ms 84.91 -width8_decimal_sum_jit 8.10ms 123.47 +width8_high_card_partial_avg_nojit 56.15ms 17.81 +width8_high_card_partial_avg_jit 61.69ms 16.21 ---------------------------------------------------------------------------- -width8_decimal_avg_nojit 16.08ms 62.21 -width8_decimal_avg_jit 
12.94ms 77.25 +width8_high_card_partial_sum_nojit 25.49ms 39.23 +width8_high_card_partial_sum_jit 22.00ms 45.46 ---------------------------------------------------------------------------- -width8_double_min_nojit 5.03ms 198.86 -width8_double_min_jit 4.16ms 240.33 +width16_high_card_partial_avg_nojit 99.27ms 10.07 +width16_high_card_partial_avg_jit 114.96ms 8.70 ---------------------------------------------------------------------------- -width8_double_max_nojit 4.17ms 239.71 -width8_double_max_jit 3.94ms 253.99 +width16_high_card_partial_sum_nojit 48.90ms 20.45 +width16_high_card_partial_sum_jit 41.31ms 24.21 ---------------------------------------------------------------------------- -width8_high_card_partial_avg_extract_nojit 57.83ms 17.29 -width8_high_card_partial_avg_extract_jit 61.68ms 16.21 ---------------------------------------------------------------------------- -width8_high_card_partial_sum_extract_nojit 25.24ms 39.62 -width8_high_card_partial_sum_extract_jit 22.27ms 44.90 ---------------------------------------------------------------------------- ``` From 25d83d63507cabf9a25d7544c6f13a7f1a07d533 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 18:00:25 +0800 Subject: [PATCH 45/98] add more bench cases --- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index fc9f2b511..d1e5389ca 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -43,25 +43,30 @@ class HashAggrJitBenchmark : public VectorTestBase { std::vector sums; std::vector avgs; std::vector mins; + std::vector maxs; std::vector counts; sums.reserve(width); avgs.reserve(width); mins.reserve(width); + maxs.reserve(width); counts.reserve(width); for (auto i = 0; i < width; ++i) { sums.push_back(fmt::format("spark_sum(c{})", i + 1)); 
avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); mins.push_back(fmt::format("min(c{})", i + 1)); + maxs.push_back(fmt::format("max(c{})", i + 1)); counts.push_back(fmt::format("count(c{})", i + 1)); } - addCase(name + "_sum", rows, sums); - addCase(name + "_avg", rows, avgs); - addCase(name + "_min", rows, mins); - addCase(name + "_count", rows, counts); + // addCase(name + "_sum", rows, sums); + // addCase(name + "_avg", rows, avgs); + // addCase(name + "_min", rows, mins); + // addCase(name + "_max", rows, maxs); + // addCase(name + "_count", rows, counts); addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_max", rows, maxs, AggregationPlanKind::PartialFinal); addCase( name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); } @@ -74,8 +79,16 @@ class HashAggrJitBenchmark : public VectorTestBase { sums.push_back(fmt::format("spark_sum(c{})", i + 1)); avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); } - addCase(name + "_decimal_sum", rows, sums, AggregationPlanKind::PartialFinal); - addCase(name + "_decimal_avg", rows, avgs, AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_decimal_sum", + rows, + sums, + AggregationPlanKind::PartialFinal); + addCase( + name + "_merge_decimal_avg", + rows, + avgs, + AggregationPlanKind::PartialFinal); } void addFloatingPointMinMaxBenchmark(const std::string& name, int32_t width) { @@ -86,8 +99,8 @@ class HashAggrJitBenchmark : public VectorTestBase { mins.push_back(fmt::format("min(c{})", i + 1)); maxs.push_back(fmt::format("max(c{})", i + 1)); } - addCase(name + "_double_min", rows, mins); - addCase(name + "_double_max", rows, maxs); + // addCase(name + "_double_min", rows, mins); + // addCase(name + "_double_max", rows, maxs); addCase( name + "_merge_double_min", rows, @@ -119,26 +132,31 
@@ class HashAggrJitBenchmark : public VectorTestBase { std::vector sums; std::vector avgs; std::vector mins; + std::vector maxs; std::vector counts; sums.reserve(width); avgs.reserve(width); mins.reserve(width); + maxs.reserve(width); counts.reserve(width); for (auto i = 0; i < width; ++i) { sums.push_back(fmt::format("spark_sum(c{})", i + 1)); avgs.push_back(fmt::format("spark_avg(c{})", i + 1)); mins.push_back(fmt::format("min(c{})", i + 1)); + maxs.push_back(fmt::format("max(c{})", i + 1)); counts.push_back(fmt::format("count(c{})", i + 1)); } - addCase(name + "_sum", rows, sums); - addCase(name + "_avg", rows, avgs); - addCase(name + "_min", rows, mins); - addCase(name + "_count", rows, counts); + // addCase(name + "_sum", rows, sums); + // addCase(name + "_avg", rows, avgs); + // addCase(name + "_min", rows, mins); + // addCase(name + "_max", rows, maxs); + // addCase(name + "_count", rows, counts); addCase(name + "_merge_sum", rows, sums, AggregationPlanKind::PartialFinal); addCase(name + "_merge_avg", rows, avgs, AggregationPlanKind::PartialFinal); addCase(name + "_merge_min", rows, mins, AggregationPlanKind::PartialFinal); + addCase(name + "_merge_max", rows, maxs, AggregationPlanKind::PartialFinal); addCase( name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); } @@ -329,10 +347,10 @@ int main(int argc, char** argv) { benchmark.addFloatingPointMinMaxBenchmark("width16", 16); benchmark.addFloatingPointMinMaxBenchmark("width32", 32); - benchmark.addHighCardinalityExtractBenchmark("width4_high_card", 4); - benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); - benchmark.addHighCardinalityExtractBenchmark("width16_high_card", 16); - benchmark.addHighCardinalityExtractBenchmark("width32_high_card", 32); + // benchmark.addHighCardinalityExtractBenchmark("width4_high_card", 4); + // benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); + // benchmark.addHighCardinalityExtractBenchmark("width16_high_card", 16); + 
// benchmark.addHighCardinalityExtractBenchmark("width32_high_card", 32); folly::runBenchmarks(); return 0; From 8040e5ad38dcec5789ce1649aa8d6fb0a69c0efd Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 19:30:21 +0800 Subject: [PATCH 46/98] =?UTF-8?q?decimal=20sum=20merge=20=E8=BE=93?= =?UTF-8?q?=E5=85=A5=20row-field=20=E5=BF=AB=E8=B7=AF=E5=BE=84=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bolt/exec/GroupingSet.cpp | 39 +++- bolt/jit/aggregation/HashAggrJit.cpp | 62 +++++ bolt/jit/aggregation/HashAggrJit.h | 7 + bolt/jit/aggregation/HashAggrJitTypes.h | 8 + bolt/jit/aggregation/ops/AvgOps.cpp | 21 +- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 2 +- doc/hashaggr-jit-benchmark.md | 260 ++++++++++----------- 7 files changed, 248 insertions(+), 151 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index e30c5b5bd..850a9bd2f 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -120,7 +120,19 @@ void fillHashAggrJitRowFieldInputs( jit::HashAggrJitDecodedInput& input, const DecodedVector& decoded, const jit::HashAggrJitSlot& slot) { - if (!slot.desc.mergeInput || slot.desc.kind != jit::HashAggrJitKind::Avg) { + // The raw row-field fast path applies to merge inputs whose intermediate + // representation is ROW(field0, field1): avg's ROW(sum, count) and decimal + // sum's ROW(sum, isEmpty). For both, field0 is the running sum read on the + // hot merge path; populating its raw pointer lets generated IR load it + // directly instead of calling the per-row jit_GetDecodedRowField* helper + // (which rebuilds a field DecodedVector on every call). 
+ if (!slot.desc.mergeInput) { + return; + } + const bool isAvg = slot.desc.kind == jit::HashAggrJitKind::Avg; + const bool isDecimalSum = + slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal; + if (!isAvg && !isDecimalSum) { return; } const auto* base = decoded.base(); @@ -132,17 +144,30 @@ void fillHashAggrJitRowFieldInputs( return; } const auto& sumVector = rowVector->childAt(0); - const auto& countVector = rowVector->childAt(1); - if (sumVector->encoding() != VectorEncoding::Simple::FLAT || - countVector->encoding() != VectorEncoding::Simple::FLAT) { + if (sumVector->encoding() != VectorEncoding::Simple::FLAT) { return; } input.rowField0Values = hashAggrJitRawInputValues(sumVector.get(), slot.desc.inputKind); input.rowField0Nulls = sumVector->rawNulls(); - input.rowField1Values = - hashAggrJitRawInputValues(countVector.get(), jit::HashAggrJitValueKind::Int64); - input.rowField1Nulls = countVector->rawNulls(); + // field1 differs by aggregate: avg's count is a flat int64 scalar; decimal + // sum's isEmpty is a bit-packed bool whose valuesAsVoid() is the bit-word + // buffer consumed by loadDecodedRowFieldBool's bit-read fast path. + const auto& field1Vector = rowVector->childAt(1); + if (field1Vector->encoding() != VectorEncoding::Simple::FLAT) { + return; + } + if (isAvg) { + input.rowField1Values = hashAggrJitRawInputValues( + field1Vector.get(), jit::HashAggrJitValueKind::Int64); + input.rowField1Nulls = field1Vector->rawNulls(); + } else { + // isEmpty is bit-packed bool: valuesAsVoid() exposes the underlying + // bit-word buffer (rawValues() throws for bool). loadDecodedRowFieldBool + // bit-reads it directly. + input.rowField1Values = field1Vector->valuesAsVoid(); + input.rowField1Nulls = field1Vector->rawNulls(); + } } // Fills the raw flat sum/count field pointers for a partial avg ROW output. 
diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 9a89731ef..930390e23 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -738,6 +738,68 @@ llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( builder().getInt8(0)); } +llvm::Value* HashAggrJitCodegen::loadDecodedRowFieldBool( + llvm::Value* decoded, + llvm::Value* row, + int32_t field) const { + // Bool ROW fields are bit-packed, so the raw values pointer addresses an + // i64 word array indexed by bit, not a byte-per-value buffer. When the raw + // pointer is populated (merge fast path), read the bit directly; otherwise + // fall back to the helper which decodes the field per row. + auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( + builder(), decoded, field, false); + auto* hasRawValues = builder().CreateICmpNE( + rawValues, + llvm::ConstantPointerNull::get( + llvm::PointerType::get(builder().getContext(), 0))); + auto* fastBlock = llvm::BasicBlock::Create( + module_.getContext(), + "row_field_bool_raw_load", + builder().GetInsertBlock()->getParent()); + auto* slowBlock = llvm::BasicBlock::Create( + module_.getContext(), + "row_field_bool_helper_load", + builder().GetInsertBlock()->getParent()); + auto* doneBlock = llvm::BasicBlock::Create( + module_.getContext(), + "row_field_bool_load_done", + builder().GetInsertBlock()->getParent()); + builder().CreateCondBr(hasRawValues, fastBlock, slowBlock); + + builder().SetInsertPoint(fastBlock); + auto* index = ::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); + // bit at 'index' inside the i64 word array: word = index >> 6, bit = index & + // 63; value = (words[word] >> bit) & 1. 
+ auto* i64Ty = builder().getInt64Ty(); + auto* words = builder().CreatePointerCast(rawValues, i64Ty->getPointerTo()); + auto* index64 = builder().CreateZExt(index, i64Ty); + auto* wordIndex = builder().CreateLShr(index64, builder().getInt64(6)); + auto* bitIndex = builder().CreateAnd(index64, builder().getInt64(63)); + auto* word = builder().CreateLoad( + i64Ty, builder().CreateInBoundsGEP(i64Ty, words, wordIndex)); + auto* shifted = builder().CreateLShr(word, bitIndex); + auto* bit = builder().CreateAnd(shifted, builder().getInt64(1)); + auto* fastValue = builder().CreateTrunc(bit, builder().getInt8Ty()); + builder().CreateBr(doneBlock); + auto* fastEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(slowBlock); + auto* decodedVector = + ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); + auto* slowValue = builder().CreateCall( + module_.getFunction("jit_GetDecodedRowFieldI8"), + {decodedVector, row, builder().getInt32(field)}); + builder().CreateBr(doneBlock); + auto* slowEnd = builder().GetInsertBlock(); + + builder().SetInsertPoint(doneBlock); + auto* value = + builder().CreatePHI(builder().getInt8Ty(), 2, "row_field_bool_value"); + value->addIncoming(fastValue, fastEnd); + value->addIncoming(slowValue, slowEnd); + return value; +} + void HashAggrJitCodegen::emitFlatValue( llvm::Value* output, llvm::Value* row, diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index a6316c247..264fcf9a7 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -104,6 +104,13 @@ class HashAggrJitCodegen { llvm::Value* decoded, llvm::Value* row, int32_t field) const; + // Reads a bit-packed bool ROW field (e.g. decimal sum's isEmpty) as an i8 + // 0/1. The raw fast path bit-reads the flat bool buffer; falls back to the + // jit_GetDecodedRowFieldI8 helper when the field's raw pointer is unset. 
+ llvm::Value* loadDecodedRowFieldBool( + llvm::Value* decoded, + llvm::Value* row, + int32_t field) const; void emitFlatValue( llvm::Value* vector, llvm::Value* row, diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index f8620f62c..490407fc1 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -2,6 +2,7 @@ #ifdef ENABLE_BOLT_JIT +#include #include #include #include @@ -17,6 +18,13 @@ namespace bytedance::bolt::jit { +// JIT-internal accumulator layout for avg. Shared between avg ops codegen and +// any runtime/helper logic that needs to reason about the in-row state layout. +struct JitAvgState { + double sum{0}; + int64_t count{0}; +}; + // JIT-internal accumulator layouts for decimal sum/avg. Shared between the JIT // codegen runtime helpers and the extract runtime helpers (which live in a // different translation unit and need DecimalUtil). diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index 33fca6860..abdc8699d 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -11,6 +11,8 @@ namespace bytedance::bolt::jit { namespace { +constexpr int32_t kAvgCountOffset = offsetof(JitAvgState, count); + void compileAvgInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -24,7 +26,7 @@ void compileAvgInitGroup( codegen.storeValue( group, codegen.builder().getInt64Ty(), - slot.offset + 8, + slot.offset + kAvgCountOffset, codegen.builder().getInt64(0)); } @@ -48,11 +50,13 @@ void compileAvgAddRawInput( slot.offset, codegen.builder().CreateFAdd(oldSum, value)); auto* oldCount = codegen.loadValue( - group, codegen.builder().getInt64Ty(), slot.offset + 8); + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset); codegen.storeValue( group, codegen.builder().getInt64Ty(), - slot.offset + 8, + slot.offset + kAvgCountOffset, codegen.builder().CreateAdd(oldCount, 
codegen.builder().getInt64(1))); } @@ -77,16 +81,18 @@ void compileAvgAddIntermediateResults( slot.offset, codegen.builder().CreateFAdd(oldSum, sum)); auto* oldCount = codegen.loadValue( - group, codegen.builder().getInt64Ty(), slot.offset + 8); + group, + codegen.builder().getInt64Ty(), + slot.offset + kAvgCountOffset); codegen.storeValue( group, codegen.builder().getInt64Ty(), - slot.offset + 8, + slot.offset + kAvgCountOffset, codegen.builder().CreateAdd(oldCount, count)); } bool canCompileAvgExtract(const HashAggrJitSlot& slot, bool) { - // Only double avg (sum=double@offset, count=int64@offset+8) is supported. + // Only double avg (JitAvgState) is supported. return slot.desc.accumulatorKind == HashAggrJitValueKind::Double; } @@ -97,7 +103,8 @@ void compileAvgExtract( const HashAggrJitExtractTarget& target) { auto& builder = codegen.builder(); auto* sum = codegen.loadValue(group, builder.getDoubleTy(), slot.offset); - auto* count = codegen.loadValue(group, builder.getInt64Ty(), slot.offset + 8); + auto* count = codegen.loadValue( + group, builder.getInt64Ty(), slot.offset + kAvgCountOffset); if (target.partialOutput) { // Intermediate output is row(sum:double, count:bigint). 
All-null group // yields (0, 0) with a non-null top-level row (isNull = 0), matching the diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index f829c9ad2..cc3e8d4e0 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -82,7 +82,7 @@ void compileDecimalSumAddIntermediateResults( codegen.module().getContext(), "sum_decimal_merge", function, continueBlock); auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); auto* incomingIsEmpty = - codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int8); + codegen.loadDecodedRowFieldBool(decoded, row, 1); auto* isNotEmpty = b.CreateICmpEQ(incomingIsEmpty, b.getInt8(0)); auto* isOverflow = b.CreateAnd(sumIsNull, isNotEmpty); b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md index 6167c8ec3..c556114cf 100644 --- a/doc/hashaggr-jit-benchmark.md +++ b/doc/hashaggr-jit-benchmark.md @@ -1195,6 +1195,7 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 ``` +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_regex="(width8|width16)" ============================================================================ [...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s ============================================================================ @@ -1203,190 +1204,110 @@ runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路 ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- +width8_merge_sum_nojit 6.36ms 157.19 +width8_merge_sum_jit 4.62ms 216.34 ---------------------------------------------------------------------------- +width8_merge_avg_nojit 7.57ms 132.09 +width8_merge_avg_jit 5.97ms 167.52 
---------------------------------------------------------------------------- +width8_merge_min_nojit 5.49ms 182.15 +width8_merge_min_jit 4.70ms 212.96 ---------------------------------------------------------------------------- -width8_sum_nojit 4.75ms 210.73 -width8_sum_jit 3.49ms 286.89 +width8_merge_max_nojit 5.59ms 178.76 +width8_merge_max_jit 4.78ms 209.31 ---------------------------------------------------------------------------- -width8_avg_nojit 5.36ms 186.61 -width8_avg_jit 4.17ms 239.53 +width8_merge_count_nojit 5.98ms 167.10 +width8_merge_count_jit 3.62ms 276.52 ---------------------------------------------------------------------------- -width8_min_nojit 3.75ms 266.36 -width8_min_jit 3.47ms 288.20 +width16_merge_sum_nojit 11.78ms 84.87 +width16_merge_sum_jit 7.96ms 125.58 ---------------------------------------------------------------------------- -width8_count_nojit 4.39ms 227.95 -width8_count_jit 2.34ms 427.14 +width16_merge_avg_nojit 15.00ms 66.69 +width16_merge_avg_jit 10.79ms 92.66 ---------------------------------------------------------------------------- -width8_merge_sum_nojit 6.22ms 160.77 -width8_merge_sum_jit 4.75ms 210.69 +width16_merge_min_nojit 10.42ms 95.97 +width16_merge_min_jit 7.78ms 128.48 ---------------------------------------------------------------------------- -width8_merge_avg_nojit 7.59ms 131.72 -width8_merge_avg_jit 5.84ms 171.16 +width16_merge_max_nojit 10.83ms 92.30 +width16_merge_max_jit 8.04ms 124.41 ---------------------------------------------------------------------------- -width8_merge_min_nojit 5.64ms 177.34 -width8_merge_min_jit 4.89ms 204.57 +width16_merge_count_nojit 9.77ms 102.32 +width16_merge_count_jit 5.12ms 195.16 ---------------------------------------------------------------------------- -width8_merge_count_nojit 6.08ms 164.52 -width8_merge_count_jit 3.68ms 271.96 ---------------------------------------------------------------------------- -width16_sum_nojit 8.94ms 111.84 -width16_sum_jit 6.03ms 165.78 
---------------------------------------------------------------------------- -width16_avg_nojit 10.71ms 93.37 -width16_avg_jit 7.39ms 135.30 ---------------------------------------------------------------------------- -width16_min_nojit 7.62ms 131.30 -width16_min_jit 6.22ms 160.80 ---------------------------------------------------------------------------- -width16_count_nojit 7.87ms 127.05 -width16_count_jit 3.62ms 275.88 ---------------------------------------------------------------------------- -width16_merge_sum_nojit 11.47ms 87.15 -width16_merge_sum_jit 7.79ms 128.42 ---------------------------------------------------------------------------- -width16_merge_avg_nojit 14.40ms 69.45 -width16_merge_avg_jit 10.53ms 94.95 ---------------------------------------------------------------------------- -width16_merge_min_nojit 10.14ms 98.61 -width16_merge_min_jit 7.73ms 129.41 ---------------------------------------------------------------------------- -width16_merge_count_nojit 9.62ms 103.94 -width16_merge_count_jit 5.16ms 193.66 ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- +width8_high_card_merge_sum_nojit 63.70ms 15.70 +width8_high_card_merge_sum_jit 52.54ms 19.03 ---------------------------------------------------------------------------- +width8_high_card_merge_avg_nojit 88.21ms 11.34 +width8_high_card_merge_avg_jit 83.18ms 12.02 ---------------------------------------------------------------------------- +width8_high_card_merge_min_nojit 61.33ms 16.30 +width8_high_card_merge_min_jit 53.34ms 18.75 ---------------------------------------------------------------------------- +width8_high_card_merge_max_nojit 63.09ms 15.85 +width8_high_card_merge_max_jit 52.99ms 18.87 ---------------------------------------------------------------------------- +width8_high_card_merge_count_nojit 60.97ms 16.40 +width8_high_card_merge_count_jit 56.22ms 17.79 
---------------------------------------------------------------------------- +width16_high_card_merge_sum_nojit 113.48ms 8.81 +width16_high_card_merge_sum_jit 90.50ms 11.05 ---------------------------------------------------------------------------- +width16_high_card_merge_avg_nojit 160.62ms 6.23 +width16_high_card_merge_avg_jit 146.38ms 6.83 ---------------------------------------------------------------------------- +width16_high_card_merge_min_nojit 116.27ms 8.60 +width16_high_card_merge_min_jit 92.11ms 10.86 ---------------------------------------------------------------------------- +width16_high_card_merge_max_nojit 113.06ms 8.84 +width16_high_card_merge_max_jit 91.86ms 10.89 ---------------------------------------------------------------------------- +width16_high_card_merge_count_nojit 100.63ms 9.94 +width16_high_card_merge_count_jit 89.21ms 11.21 ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- -width8_high_card_sum_nojit 40.01ms 25.00 -width8_high_card_sum_jit 31.61ms 31.63 ---------------------------------------------------------------------------- -width8_high_card_avg_nojit 46.37ms 21.56 -width8_high_card_avg_jit 38.18ms 26.19 ---------------------------------------------------------------------------- -width8_high_card_min_nojit 37.29ms 26.82 -width8_high_card_min_jit 30.62ms 32.66 +width8_merge_decimal_sum_nojit 19.57ms 51.09 +width8_merge_decimal_sum_jit 21.68ms 46.12 ---------------------------------------------------------------------------- -width8_high_card_count_nojit 34.43ms 29.04 -width8_high_card_count_jit 30.06ms 33.27 
+width8_merge_decimal_avg_nojit 19.50ms 51.29 +width8_merge_decimal_avg_jit 13.67ms 73.15 ---------------------------------------------------------------------------- -width8_high_card_merge_sum_nojit 61.33ms 16.31 -width8_high_card_merge_sum_jit 52.05ms 19.21 +width16_merge_decimal_sum_nojit 39.46ms 25.34 +width16_merge_decimal_sum_jit 42.44ms 23.56 ---------------------------------------------------------------------------- -width8_high_card_merge_avg_nojit 94.24ms 10.61 -width8_high_card_merge_avg_jit 78.64ms 12.72 +width16_merge_decimal_avg_nojit 40.01ms 24.99 +width16_merge_decimal_avg_jit 26.90ms 37.17 ---------------------------------------------------------------------------- -width8_high_card_merge_min_nojit 62.70ms 15.95 -width8_high_card_merge_min_jit 51.41ms 19.45 ---------------------------------------------------------------------------- -width8_high_card_merge_count_nojit 57.81ms 17.30 -width8_high_card_merge_count_jit 53.43ms 18.72 ---------------------------------------------------------------------------- -width16_high_card_sum_nojit 70.22ms 14.24 -width16_high_card_sum_jit 55.12ms 18.14 ---------------------------------------------------------------------------- -width16_high_card_avg_nojit 84.58ms 11.82 -width16_high_card_avg_jit 66.30ms 15.08 ---------------------------------------------------------------------------- -width16_high_card_min_nojit 67.12ms 14.90 -width16_high_card_min_jit 54.09ms 18.49 +width8_merge_double_min_nojit 6.82ms 146.54 +width8_merge_double_min_jit 5.58ms 179.12 ---------------------------------------------------------------------------- -width16_high_card_count_nojit 59.38ms 16.84 -width16_high_card_count_jit 47.88ms 20.89 +width8_merge_double_max_nojit 5.89ms 169.65 +width8_merge_double_max_jit 5.25ms 190.33 ---------------------------------------------------------------------------- -width16_high_card_merge_sum_nojit 113.95ms 8.78 -width16_high_card_merge_sum_jit 93.46ms 10.70 +width16_merge_double_min_nojit 12.33ms 
81.08 +width16_merge_double_min_jit 9.74ms 102.70 ---------------------------------------------------------------------------- -width16_high_card_merge_avg_nojit 157.39ms 6.35 -width16_high_card_merge_avg_jit 137.35ms 7.28 ----------------------------------------------------------------------------- -width16_high_card_merge_min_nojit 110.10ms 9.08 -width16_high_card_merge_min_jit 91.93ms 10.88 ----------------------------------------------------------------------------- -width16_high_card_merge_count_nojit 103.19ms 9.69 -width16_high_card_merge_count_jit 88.89ms 11.25 ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_decimal_sum_nojit 12.15ms 82.32 -width8_decimal_sum_jit 8.16ms 122.56 ----------------------------------------------------------------------------- -width8_decimal_avg_nojit 16.20ms 61.74 -width8_decimal_avg_jit 9.55ms 104.73 ----------------------------------------------------------------------------- -width16_decimal_sum_nojit 23.55ms 42.46 -width16_decimal_sum_jit 16.41ms 60.93 ----------------------------------------------------------------------------- -width16_decimal_avg_nojit 32.43ms 30.83 -width16_decimal_avg_jit 19.10ms 52.36 
----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_double_min_nojit 4.97ms 201.41 -width8_double_min_jit 4.19ms 238.53 ----------------------------------------------------------------------------- -width8_double_max_nojit 4.22ms 236.76 -width8_double_max_jit 3.89ms 257.11 ----------------------------------------------------------------------------- -width8_merge_double_min_nojit 6.96ms 143.68 -width8_merge_double_min_jit 5.71ms 175.20 ----------------------------------------------------------------------------- -width8_merge_double_max_nojit 6.01ms 166.45 -width8_merge_double_max_jit 5.10ms 195.97 ----------------------------------------------------------------------------- -width16_double_min_nojit 9.93ms 100.71 -width16_double_min_jit 7.81ms 128.06 ----------------------------------------------------------------------------- -width16_double_max_nojit 8.75ms 114.27 -width16_double_max_jit 7.18ms 139.30 ----------------------------------------------------------------------------- -width16_merge_double_min_nojit 12.39ms 80.74 -width16_merge_double_min_jit 9.53ms 104.88 ----------------------------------------------------------------------------- -width16_merge_double_max_nojit 10.93ms 91.50 -width16_merge_double_max_jit 9.04ms 110.68 ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- 
----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_high_card_partial_avg_nojit 56.15ms 17.81 -width8_high_card_partial_avg_jit 61.69ms 16.21 ----------------------------------------------------------------------------- -width8_high_card_partial_sum_nojit 25.49ms 39.23 -width8_high_card_partial_sum_jit 22.00ms 45.46 ----------------------------------------------------------------------------- -width16_high_card_partial_avg_nojit 99.27ms 10.07 -width16_high_card_partial_avg_jit 114.96ms 8.70 ----------------------------------------------------------------------------- -width16_high_card_partial_sum_nojit 48.90ms 20.45 -width16_high_card_partial_sum_jit 41.31ms 24.21 +width16_merge_double_max_nojit 10.85ms 92.15 +width16_merge_double_max_jit 8.67ms 115.36 ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- ---------------------------------------------------------------------------- @@ -1449,3 +1370,70 @@ divide / overflow / precision-rescale 逻辑直接展开成 LLVM IR。 换句话说:从这一节之后,文档里关于 decimal avg 的性能讨论应默认理解为“**final extract JIT 已补齐**”的版本; 如果引用更早的数据,需要显式说明那是旧口径历史快照。 + +## 16. 
decimal sum merge 输入 row-field 快路径优化 + +### 16.1 背景 + +补齐 final extract 后,新增了 `*_merge_decimal_sum` / `*_merge_decimal_avg`(`PartialFinal` 口径)benchmark。 +其中 **`width16_merge_decimal_sum_jit` 反而比 non-JIT 慢**(约 42ms vs 39ms),而同口径的 decimal avg 却是 JIT 更快。 + +定位结论:瓶颈不在算术本身,而在 **final aggregation 的 merge 输入读取**。 + +- decimal sum merge 的中间结果是 `ROW(sum:decimal, isEmpty:bool)`; +- merge 热路径每行都要读 field0(sum) 与 field1(isEmpty),并判 field0 null(见 `DecimalSumOps.cpp`); +- 但 `fillHashAggrJitRowFieldInputs()` 当时**只为 avg 预填 row-field raw 指针**,decimal sum 没填; +- 于是 JIT 的 `loadDecodedRowField` / `isDecodedRowFieldNull` 全部掉到 helper slow path + (`jit_GetDecodedRowFieldI128 / I8 / IsNull`),而这些 helper **每次调用都重建一个 field 级 `DecodedVector`**; +- width16 下 slot 数翻倍,这个每行固定开销被线性放大,最终把 JIT 收益吃光。 + +decimal avg 之所以更快,正是因为它的 merge 输入早已走了 raw row-field 快路径(avg 的 `ROW(sum, count)`)。 + +### 16.2 改动 + +分两步把 decimal sum 的 merge 输入快路径补齐。 + +**第一步:field0(sum) raw fast path** + +扩展 `fillHashAggrJitRowFieldInputs()`(`bolt/exec/GroupingSet.cpp`),从“仅 avg”扩展到“avg + decimal sum”: + +- decimal sum 的 field0(sum, int128) 填充 `rowField0Values` / `rowField0Nulls`; +- 这样 JIT 读取 sum 与判 sum null 直接命中 `loadDecodedRowField` 的 raw fast path,不再每行重建 `DecodedVector`。 + +**第二步:field1(isEmpty) bit-packed bool fast path** + +field1 的 `isEmpty` 是 **bit-packed bool**,没有按字节排布的 scalar 指针,不能复用普通 fast path,因此单独新增一条按位读取的快路径: + +- 新增 `HashAggrJitCodegen::loadDecodedRowFieldBool()`(`bolt/jit/aggregation/HashAggrJit.cpp`): + 把 `rowField1Values` 视为 i64 word 数组,`word = index>>6`、`bit = index&63`,直接 `(words[word] >> bit) & 1`; + raw 指针为空时回退 `jit_GetDecodedRowFieldI8` helper; +- `DecimalSumOps.cpp` 的 merge 改用 `loadDecodedRowFieldBool` 读 `isEmpty`; +- `fillHashAggrJitRowFieldInputs()` 为 decimal sum 填 field1 的 bit 缓冲区指针, + 通过 `valuesAsVoid()` 取(注意:`FlatVector::rawValues()` 对 bool 会抛 `UNSUPPORTED`,必须用 `valuesAsVoid()`)。 + +涉及文件: + +- `bolt/exec/GroupingSet.cpp`:`fillHashAggrJitRowFieldInputs` 扩展支持 decimal sum,并填 field0/field1 指针; +- 
`bolt/jit/aggregation/HashAggrJit.{h,cpp}`:新增 `loadDecodedRowFieldBool`; +- `bolt/jit/aggregation/ops/DecimalSumOps.cpp`:merge 改用 bool 快路径读 `isEmpty`。 + +### 16.3 性能(PartialFinal 口径,bm_min_iters 默认) + +| case | nojit | jit(初始) | jit(+field0) | jit(+bool) | +|---|---|---|---|---| +| width8_merge_decimal_sum | 19.55ms | 慢于 nojit | 17.24ms | **16.19ms** | +| width16_merge_decimal_sum | 39.44ms | ~42ms(慢于 nojit) | 33.27ms | **31.30ms** | +| width8_merge_decimal_avg | 20.68ms | —(本就更快) | — | **14.55ms** | +| width16_merge_decimal_avg | 40.09ms | —(本就更快) | — | **28.12ms** | + +要点: + +- `width16_merge_decimal_sum_jit`:42ms(慢于 nojit)→ 33.3ms(field0 快路径)→ **31.3ms**(再加 bool 快路径),已稳定反超 nojit 的 39ms; +- decimal avg 不受负面影响,仍保持 JIT 更快。 + +### 16.4 正确性 + +- 两步快路径读到的都是与原 helper 完全相同的底层数据,仅省掉了每行的 `DecodedVector` 重建,无语义变化; +- 实现期间踩到一个坑:最初用 `FlatVector::rawValues()` 取 bit 缓冲区,运行时抛 + `BoltUserError: rawValues() for bool is not supported`,改用 `valuesAsVoid()` 后正常; +- benchmark 真实 query 在 JIT / non-JIT 双路径下均正常执行、无崩溃、无异常。 From 121d1f71601dad5f4aceb94cbeabb1eda816c1e1 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 20:35:18 +0800 Subject: [PATCH 47/98] add refactor plan doc --- hashaggr_jit_refactor_plan.md | 601 ++++++++++++++++++++++++++++++++++ 1 file changed, 601 insertions(+) create mode 100644 hashaggr_jit_refactor_plan.md diff --git a/hashaggr_jit_refactor_plan.md b/hashaggr_jit_refactor_plan.md new file mode 100644 index 000000000..8e9cf86c8 --- /dev/null +++ b/hashaggr_jit_refactor_plan.md @@ -0,0 +1,601 @@ +# Bolt Hash Aggregation JIT 框架重构落地方案 + +> 目标读者:AI/工程师,按本文档执行即可完成 `hash_aggr_jit` 分支当前框架的重构落地。 +> 适用版本:`dp/bolt @ hash_aggr_jit` 分支(基于 commit `9a65fd2` 之后)。 +> 本方案只描述 **JIT 框架层**重构,不涉及非 JIT codepath。 + +--- + +## 0. 
TL;DR + +把当前 `HashAggrJitDecodedInput / HashAggrJitOutput / per-aggregate codegen` 这一套耦合实现,重构为 **三层正交架构**: + +``` +┌──────────────────┐ IRRow ┌─────────────┐ IRRow ┌────────────────┐ +│ InputAdapter │ ─────────▶ │ GroupOps │ ─────────▶ │ OutputAdapter │ +│ (Vector → IR) │ │ (IR ↔ Group)│ │ (IR → Vector) │ +└──────────────────┘ └─────────────┘ └────────────────┘ +``` + +三层之间的唯一传输格式是 **LLVM First-Class Aggregate** 类型: + +``` +IRRow_t = llvm::StructType::get(value_type, i1_ty) + = { T, i1 } // T 由 aggregate 自己决定,可以是复合类型 +``` + +`is_null` 永远在第二个字段,框架统一处理;`value_type` 内部结构对框架透明。 + +--- + +## 1. 当前问题(背景) + +落地前必须理解这些已存在的痛点,重构必须**逐项消除**。 + +### 1.1 数据结构无通用性 + +```cpp +// HashAggrJit.h —— 反例 +struct HashAggrJitDecodedInput { + const void* data; + const uint64_t* nulls; + // ... 写死了若干字段,新增 aggregate 类型就要扩字段 +}; +struct HashAggrJitOutput { /* 同上 */ }; +``` + +- **病症**:每加一种聚合 / 一种 vector encoding,就要改这两个结构 + 改 IR 的 hardcoded byte offset。 +- **影响**:ABI 双向耦合(C++ struct ↔ IR offset),任何字段重排都是坑。 + +### 1.2 Vector ↔ IR 与 IR ↔ Group 两段逻辑混在一起 + +每个 `XxxAggregate::codegenAddDense / codegenExtract` 同时做: +1. 从输入 vector decode 出值 +2. 在 IR 里做累加 / 比较 +3. 把结果按 group 内 memory layout 写回 + +→ 三件事完全不正交,维护成本爆炸;且每个 aggregate 都要重新写 vector decoding 逻辑。 + +### 1.3 复合 value 类型(avg)特殊化 + +avg intermediate 当前在多处直接写成三元组 `{f64 sum, i64 count, i1 is_null}`,把 null 处理跟 value 内部结构耦合在一起,框架 helper 无法复用。 + +--- + +## 2. 
目标架构 + +### 2.1 核心抽象:`IRRow` + +**契约**: + +```cpp +// 框架级 invariant —— 所有 aggregate 共用 +IRRow_t(value_type) := llvm::StructType::get(value_type, i1Ty) +// ^^^^^^^^^^ ^^^^^ +// field 0 field 1 (is_null) +``` + +**关键决策(已与作者确认)**:当 `value_type` 本身是复合类型(如 avg 的 `{double sum, i64 count}`),**采用嵌套** `{{double, i64}, i1}`,不采用平铺 `{double, i64, i1}`。 + +理由(简版,详细对比见 §6): +- 嵌套保持 `IRRow = {T, i1}` 不变量,框架 helper 完全通用; +- 平铺让 framework 必须知道 T 内部 field 数量,破坏抽象; +- 二者 memory layout 完全相同(24B),lowering 后寄存器分配完全一致,**性能零差异**; +- 未来 stddev / HLL / array_agg 等复合 value 聚合都能复用同一套框架。 + +### 2.2 三层职责 + +| 层 | 输入 | 输出 | 不该做 | +|----|------|------|--------| +| **InputAdapter** | `BaseVector*` + row index (IR) | `IRRow`(in register) | 不感知 group memory | +| **GroupOps** | `IRRow` + `group ptr`(IR) | 写回 group / 产出新 `IRRow` | 不感知 vector encoding | +| **OutputAdapter** | `IRRow` + `BaseVector*` + row index | 写回 vector | 不感知 group memory | + +每一层都对其它两层透明 —— 通过 IRRow 的标准接口(见 §3)通信。 + +### 2.3 调用链对应关系 + +| 算子方法 | 三层调用链 | +|----------|-----------| +| `addRawInput` | `InputAdapter::read(rawVec, i)` → `GroupOps::accumulate(group, IRRow)` | +| `addIntermediateResults` | `InputAdapter::read(intVec, i)` → `GroupOps::merge(group, IRRow)` | +| `extractIntermediateResults` | `GroupOps::loadIntermediate(group)` → `OutputAdapter::write(intVec, i, IRRow)` | +| `extractResults` | `GroupOps::finalize(group)` → `OutputAdapter::write(finalVec, i, IRRow)` | +| `initGroup` | `GroupOps::init(group)` | + +--- + +## 3. 
框架层 API(必须实现) + +新增文件:`velox/exec/jit/IRRow.h`、`velox/exec/jit/InputAdapter.h`、`velox/exec/jit/GroupOps.h`、`velox/exec/jit/OutputAdapter.h`(路径按 bolt 现有 jit 目录调整)。 + +### 3.1 `IRRow` —— 唯一传输格式 + +```cpp +class IRRow { + public: + // 类型构造:value_type 由 aggregate 决定 + static llvm::StructType* getType(llvm::IRBuilder<>& b, llvm::Type* value_type) { + return llvm::StructType::get(value_type, b.getInt1Ty()); + } + + // ---- 读 ---- + static llvm::Value* getValue(llvm::IRBuilder<>& b, llvm::Value* row) { + return b.CreateExtractValue(row, {0}); + } + static llvm::Value* getIsNull(llvm::IRBuilder<>& b, llvm::Value* row) { + return b.CreateExtractValue(row, {1}); + } + + // ---- 写 ---- + static llvm::Value* pack(llvm::IRBuilder<>& b, + llvm::Value* val, + llvm::Value* is_null) { + auto* ty = llvm::StructType::get(val->getType(), is_null->getType()); + auto* tmp = b.CreateInsertValue(llvm::UndefValue::get(ty), val, {0}); + return b.CreateInsertValue(tmp, is_null, {1}); + } + + static llvm::Value* withValue(llvm::IRBuilder<>& b, + llvm::Value* row, + llvm::Value* val) { + return b.CreateInsertValue(row, val, {0}); + } + static llvm::Value* withIsNull(llvm::IRBuilder<>& b, + llvm::Value* row, + llvm::Value* is_null) { + return b.CreateInsertValue(row, is_null, {1}); + } + + // ---- 复合 value 的二级访问:仅在 GroupOps 内部使用 ---- + static llvm::Value* getValueField(llvm::IRBuilder<>& b, + llvm::Value* row, + unsigned idx) { + return b.CreateExtractValue(row, {0, idx}); // 注意:嵌套索引的 extractvalue(作用于 SSA 聚合值,并非 GEP) + } +}; +``` + +**强约束**:除了 `IRRow` 这套 helper 之外,**任何代码不得**直接对 IRRow struct 做 `extractvalue` / `insertvalue` —— 一旦发现就是抽象泄漏。 + +### 3.2 `InputAdapter` —— 规范化输入描述 → IRRow + +```cpp +class InputAdapter { + public: + virtual ~InputAdapter() = default; + + // 在 codegen 阶段调用,返回 IR 类型(必须等于对应 aggregate 的 IRRow_t) + virtual llvm::StructType* irRowType(llvm::IRBuilder<>& b) const = 0; + + // 在 IRBuilder 当前位置生成读取代码:从 vector + index 读出一个 IRRow + virtual llvm::Value* read(llvm::IRBuilder<>& b, + llvm::Value* vector_ctx, +
llvm::Value* row_idx) const = 0; +}; +``` + +**关键修正**:这里的 `InputAdapter` **不能**按原始 vector encoding(flat / constant / dictionary)拆成不同 JIT 实现。 + +原因是:同一个 compiled chunk 会反复运行在不同 batch 上,而 batch 的原始 encoding 可以变化。若按原始 +encoding 生成不同 IR,则 JIT module cache key 会被 batch 形态污染,代码无法收敛,甚至会退化成“按批次特化并反复编译”。 + +因此,正确边界应当是: + +```text +原始 Vector(flat/constant/dictionary/...) + │ + ▼ +GroupingSet / DecodedVector 先做批次级规范化 + │ + ▼ +Canonical decoded descriptor + { values, indices, nulls, decodedVector, rowField*... } + │ + ▼ +InputAdapter 只针对“规范化后的运行时描述”生成 IR +``` + +也就是说: + +- flat / constant / dictionary 的差异,应该在 **JIT 之前** 被 `DecodedVector` + runtime descriptor 吸收; +- JIT 内的 `InputAdapter` 面向的是**稳定 ABI**,而不是每个 batch 的原始 encoding; +- 这样生成出来的 IR 才能在不同 batch 上复用并保持收敛。 + +**实现一览**(最少需要这些 adapter,它们对应“规范化后的输入形态”,而不是原始 encoding): + +| Adapter | 处理的规范化形态 | 关键 IR 行为 | +|---------|------------------|--------------| +| `DecodedScalarInputAdapter` | 标量输入:`values + indices + nulls` | `index = indices[row]`;`gep + load` 数据;`bit test` top-level nulls;pack 成 IRRow | +| `DecodedRowInputAdapter` | ROW intermediate:`rowField* + decodedVector(fallback)` | 优先走 field raw pointers/nulls;必要时回退 row-field helper;在 IR 里构造嵌套 IRRow | +| `CountStarInputAdapter` | 无实参输入 | 直接产出固定非空 IRRow / 或由 GroupOps 特判 | + +> 每个 adapter **只负责自己**对应的“规范化输入 contract”到 IRRow 的转换,不涉及任何聚合语义。 +> +> 特别注意:`DecodedScalarInputAdapter` 生成的 IR 在 flat / constant / dictionary batch 上应完全相同;不同 batch +> 只通过 `indices/nulls/values` 的运行时内容体现差异,而不改变 IR 形状。 + +### 3.3 `GroupOps` —— IRRow ↔ Group + +**关键修正**:`GroupOps` 在 bolt 当前实现里,**不应该**被设计成“拥有 group layout / group size / group align”的抽象。 + +当前事实是: + +- group memory 由 `RowContainer + AggregateInfo + accumulator layout` 共同决定; +- JIT 侧真正拿到的是 `group ptr + HashAggrJitSlot`; +- 访问状态依赖 `slot.offset / slot.nullByte / slot.nullMask`,以及像 `JitAvgState` / `JitDecimal*State` + 这样的现有 state struct offset; +- 当前 `HashAggrJitOps` 也是围绕这个 contract 工作,而不是自己管理 group allocation。 + +因此,更贴近 bolt 现状的 `GroupOps` 
应该是:**“在既有 slot/layout 之上生成 group state 读写 IR 的薄层 policy”**,而不是一个重新定义 group 存储协议的 owner。 + +```cpp +class GroupOps { + public: + virtual ~GroupOps() = default; + + // 该聚合的 intermediate value type(不含 is_null,框架自动包一层) + virtual llvm::Type* intermediateValueType(llvm::IRBuilder<>& b) const = 0; + virtual llvm::Type* finalValueType(llvm::IRBuilder<>& b) const = 0; + + // ---- codegen hooks ---- + // slot 提供当前 aggregate 在 group row 中的 offset/null-bit 等元数据。 + virtual void init(HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) const = 0; + + // 用 raw input 的 IRRow 累加进 group(对应当前 addRawInput)。 + virtual void accumulate(HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* input_irrow, + const HashAggrJitSlot& slot, + llvm::BasicBlock* nextBlock) const = 0; + + // 用 partial / intermediate 的 IRRow 合并进 group(对应当前 addIntermediateResults)。 + virtual void merge(HashAggrJitCodegen& codegen, + llvm::Value* group, + llvm::Value* intermediate_irrow, + const HashAggrJitSlot& slot, + llvm::BasicBlock* nextBlock) const = 0; + + // 从 group 读出 intermediate IRRow(extractIntermediateResults) + virtual llvm::Value* loadIntermediate(HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) const = 0; + + // 从 group 读出 final IRRow(extractResults) + virtual llvm::Value* finalize(HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot) const = 0; + + virtual bool canExtract(const HashAggrJitSlot& slot, + bool partialOutput) const = 0; +}; +``` + +**关键约束**: +1. `loadIntermediate` 返回的 IRRow 类型 = `IRRow::getType(intermediateValueType())`;`finalize` 返回 = `IRRow::getType(finalValueType())`。 +2. group 内 memory layout **不是 `GroupOps` 自己分配/注册**的;它依旧来源于现有 accumulator/state layout,`GroupOps` 只是通过 `slot + state field offset` 去访问。 +3. 第一阶段 `GroupOps` 可以是当前 `HashAggrJitOps` 的**薄 facade**:先把“状态读写逻辑”从 aggregate ops 中理顺,不要求第一步就重写整个 JIT chunk 生成框架。 +4. 
**null 处理统一在这一层完成**:`accumulate / merge` 必须显式处理 `IRRow::getIsNull(input)`,框架不再依赖任何外部状态。 +5. `nextBlock` 仍作为参数保留,是为了兼容当前 `genAddDenseIR(...)` 的控制流拼装方式;不要为了追求接口漂亮而强行重写外层 loop/branch 骨架。 + +### 3.3.1 与当前 `HashAggrJitOps` 的映射 + +为了降低迁移风险,建议第一阶段直接保持与现有 `HashAggrJitOps` 一一对应: + +| 当前接口 | 收敛后的职责 | +|----------|--------------| +| `initGroup` | `GroupOps::init` | +| `addRawInput` | `InputAdapter::read(raw)` → `GroupOps::accumulate` | +| `addIntermediateResults` | `InputAdapter::read(intermediate)` → `GroupOps::merge` | +| `canExtract` | `GroupOps::canExtract` | +| `extract` | `GroupOps::loadIntermediate/finalize` → `OutputAdapter::write` | + +也就是说,**第一步不是删掉 `HashAggrJitOps`,而是让它退化为一个桥接层**: + +- 对外仍维持当前 JIT chunk 代码生成入口; +- 对内逐步把输入读取 / group 状态访问 / 输出写回转发到新三层; +- 等所有 aggregate 都迁完后,再决定是否彻底折叠旧表结构。 + +### 3.4 `OutputAdapter` —— IRRow → Vector + +```cpp +class OutputAdapter { + public: + virtual ~OutputAdapter() = default; + + virtual llvm::StructType* irRowType(llvm::IRBuilder<>& b) const = 0; + + // 把 IRRow 写入 vector[row_idx] + virtual void write(llvm::IRBuilder<>& b, + llvm::Value* vector_ctx, + llvm::Value* row_idx, + llvm::Value* irrow) const = 0; +}; +``` + +输出端通常只需要 `FlatOutputAdapter` 和 `RowOutputAdapter`(写复合 intermediate)。 + +--- + +## 4. 
各聚合落地示例 + +### 4.1 `sum` / `sum`(最简单) + +```cpp +class SumGroupOps : public GroupOps { + llvm::Type* intermediateValueType(IRBuilder& b) const override { return b.getInt64Ty(); } + llvm::Type* finalValueType(IRBuilder& b) const override { return b.getInt64Ty(); } + + void init(HashAggrJitCodegen& codegen, + Value* group, + const HashAggrJitSlot& slot) const override { + codegen.setAccumulatorNull(group, slot); + codegen.storeValue(group, codegen.builder().getInt64Ty(), slot.offset, + codegen.builder().getInt64(0)); + } + + void accumulate(HashAggrJitCodegen& codegen, + Value* group, + Value* in, + const HashAggrJitSlot& slot, + BasicBlock*) const override { + auto& b = codegen.builder(); + auto* in_null = IRRow::getIsNull(b, in); + auto* in_val = IRRow::getValue(b, in); + // if (!in_null) { sum += in_val; is_null = false; } + BasicBlock *if_t = ..., *cont = ...; + b.CreateCondBr(b.CreateNot(in_null), if_t, cont); + b.SetInsertPoint(if_t); + auto* old = codegen.loadValue(group, b.getInt64Ty(), slot.offset); + codegen.storeValue(group, b.getInt64Ty(), slot.offset, b.CreateAdd(old, in_val)); + codegen.clearAccumulatorNull(group, slot); + b.CreateBr(cont); + b.SetInsertPoint(cont); + } + + void merge(HashAggrJitCodegen& codegen, + Value* group, + Value* in, + const HashAggrJitSlot& slot, + BasicBlock* next) const override { + accumulate(codegen, group, in, slot, next); + } + + Value* loadIntermediate(HashAggrJitCodegen& codegen, + Value* group, + const HashAggrJitSlot& slot) const override { + auto& b = codegen.builder(); + return IRRow::pack( + b, + codegen.loadValue(group, b.getInt64Ty(), slot.offset), + codegen.isAccumulatorNull(group, slot)); + } + Value* finalize(HashAggrJitCodegen& codegen, + Value* group, + const HashAggrJitSlot& slot) const override { + return loadIntermediate(codegen, group, slot); + } + + bool canExtract(const HashAggrJitSlot&, bool) const override { + return true; + } +}; +``` + +### 4.2 `avg`(复合 value,**采用嵌套**) + +```cpp +class AvgGroupOps 
: public GroupOps { + // intermediate value = { double sum, i64 count };is_null 由框架包外层 + llvm::Type* intermediateValueType(IRBuilder& b) const override { + return llvm::StructType::get(b.getDoubleTy(), b.getInt64Ty()); + } + llvm::Type* finalValueType(IRBuilder& b) const override { return b.getDoubleTy(); } + + // 注意:这里不是重新定义 group layout,而是复用现有 accumulator/state layout。 + // 当前 bolt 中 avg 仍应与 JitAvgState / slot.offset / kAvgCountOffset 保持一致, + // 避免第一阶段重构把 state ABI 一起打散。 + + void init(HashAggrJitCodegen& codegen, + Value* group, + const HashAggrJitSlot& slot) const override { + ... // 对齐当前 compileAvgInitGroup:setAccumulatorNull + sum/count 初始化 + } + + // raw input: IRRow_t = { double, i1 } + void accumulate(HashAggrJitCodegen& codegen, + Value* group, + Value* in, + const HashAggrJitSlot& slot, + BasicBlock*) const override { + auto& b = codegen.builder(); + auto* in_null = IRRow::getIsNull(b, in); + auto* in_val = IRRow::getValue(b, in); + // if (!in_null) { sum += val; count += 1; } + ... + } + + // intermediate: IRRow_t = { {double, i64}, i1 } + void merge(HashAggrJitCodegen& codegen, + Value* group, + Value* in, + const HashAggrJitSlot& slot, + BasicBlock* nextBlock) const override { + auto& b = codegen.builder(); + auto* in_null = IRRow::getIsNull(b, in); + auto* part_sum = IRRow::getValueField(b, in, 0); // double + auto* part_count = IRRow::getValueField(b, in, 1); // i64 + // if (!in_null) { sum += part_sum; count += part_count; } + ... 
+ } + + Value* loadIntermediate(HashAggrJitCodegen& codegen, + Value* group, + const HashAggrJitSlot& slot) const override { + auto& b = codegen.builder(); + auto* count = loadCount(b, group); + auto* sum = loadSum(b, group); + auto* is_null = b.CreateICmpEQ(count, b.getInt64(0)); + // 构造嵌套 struct { double, i64 } + auto* inner_ty = intermediateValueType(b); + auto* inner = b.CreateInsertValue(UndefValue::get(inner_ty), sum, {0}); + inner = b.CreateInsertValue(inner, count, {1}); + return IRRow::pack(b, inner, is_null); + } + + Value* finalize(HashAggrJitCodegen& codegen, + Value* group, + const HashAggrJitSlot& slot) const override { + auto& b = codegen.builder(); + auto* count = loadCount(b, group); + auto* sum = loadSum(b, group); + auto* is_null = b.CreateICmpEQ(count, b.getInt64(0)); + auto* avg = b.CreateFDiv(sum, b.CreateSIToFP(count, b.getDoubleTy())); + return IRRow::pack(b, avg, is_null); + } + + bool canExtract(const HashAggrJitSlot& slot, bool partialOutput) const override { + return ...; // 第一阶段直接镜像当前 canCompileAvgExtract 语义 + } +}; +``` + +`InputAdapter` 端只需要: +- raw 输入 → `DecodedScalarInputAdapter` (IRRow = `{double, i1}`) +- intermediate 输入 → `DecodedRowInputAdapter`,读取规范化后的 `rowField* / decodedVector(fallback)`,自动构造嵌套 IRRow + +### 4.3 `count` + +```cpp +class CountGroupOps : public GroupOps { + llvm::Type* intermediateValueType(IRBuilder& b) const override { return b.getInt64Ty(); } + llvm::Type* finalValueType(IRBuilder& b) const override { return b.getInt64Ty(); } + // count 永远不是 null,is_null 字段恒为 false(LLVM 会优化掉) + // ... 
+}; +``` + +> 实现上应继续贴合当前 bolt:`count(*)` 的 raw-input 路径本质是 `+1`,而非真的去读取一个输入列; +> merge 路径则读取 intermediate bigint。不要为了统一接口而把 `count(*)` 硬塞进一个虚构输入列模型里。 + +### 4.4 `min / max` + +类似 sum,把 `Add` 换成 `select(cmp, old, new)` 即可。 + +### 4.5 `stddev`(前瞻验证,体现可扩展性) + +```cpp +llvm::Type* intermediateValueType(IRBuilder& b) const override { + return llvm::StructType::get(b.getInt64Ty(), // count + b.getDoubleTy(), // mean + b.getDoubleTy()); // M2 +} +// IRRow_t = { {i64, double, double}, i1 },框架完全无需改动 +``` + +--- + +## 5. 与现有代码的对接 + +### 5.1 删除项 + +| 文件 / 符号 | 处置 | +|-------------|------| +| `struct HashAggrJitDecodedInput` | 第一阶段**不删除**;先把它收敛为 canonical decoded descriptor,供 `InputAdapter` 消费;所有 aggregate 迁移完成后再决定是否改名/瘦身 | +| `struct HashAggrJitOutput` | 第一阶段**不删除**;先把它收敛为 canonical output descriptor,供 `OutputAdapter` 消费;待 extract 全迁完后再决定是否改名/瘦身 | +| 各 `XxxAggregate::codegenAddDense` 中关于 vector decoding 的代码 | 迁到 `InputAdapter` | +| 各 `XxxAggregate::codegenAddDense` 中关于 group rw 的代码 | 迁到 `GroupOps` | +| `Aggregate::numNulls_` 的更新依赖(JIT path) | 删除依赖 | +| 任何 `extractvalue` / `insertvalue` 直接对 IRRow 做的代码 | 替换成 `IRRow::*` helper | + +> 额外说明:第一阶段的目标是**理顺职责边界**,不是立即改变 group row 的底层存储协议; +> `slot.offset/nullByte/nullMask` 与现有 state struct offset 仍然是合法的迁移期依赖。 + +### 5.2 新增项 + +``` +velox/exec/jit/ +├── IRRow.h +├── InputAdapter.h +├── input_adapters/ +│ ├── DecodedScalarInputAdapter.h +│ ├── DecodedRowInputAdapter.h +│ └── CountStarInputAdapter.h +├── GroupOps.h +├── group_ops/ +│ ├── SumGroupOps.{h,cpp} +│ ├── CountGroupOps.{h,cpp} +│ ├── AvgGroupOps.{h,cpp} +│ ├── MinMaxGroupOps.{h,cpp} +│ └── ... +├── OutputAdapter.h +└── output_adapters/ + ├── FlatOutputAdapter.h + └── RowOutputAdapter.h +``` + +### 5.3 单测要求 + +新增测试 `HashAggrJitFrameworkTest.cpp`,必须覆盖: + +1. `IRRow::pack/getValue/getIsNull` 在简单类型与嵌套类型上 round-trip。 +2. `DecodedScalarInputAdapter` 在 flat/constant/dictionary 三种 batch encoding 上生成**同一形状 IR**,并通过不同的 `indices/nulls/values` runtime 内容得到正确结果。 +3. 
`DecodedRowInputAdapter` 在 row-field raw fast path 与 helper fallback 两条路径上结果一致。 +4. 每个 `GroupOps`:init → accumulate(若干 raw + 若干 null) → loadIntermediate → merge(到另一 group) → finalize 与 reference 实现一致。 +5. **专项 null 测试**:所有输入都是 null 时,`finalize` 必须返回 `is_null = true`。 +6. avg intermediate 必须验证 IRRow 的 LLVM type 字面就是 `{ {double, i64}, i1 }`(而非平铺)。 + +--- + +## 6. 嵌套 vs 平铺:决策记录(avg 等复合 value) + +| 维度 | 嵌套 `{{double,i64}, i1}` ✅ | 平铺 `{double, i64, i1}` ❌ | +|------|----------------------------|---------------------------| +| `IRRow = {T, i1}` invariant | 保持 | 破坏 | +| `IRRow::getValue / getIsNull` 是否通用 | 是(`{0}` / `{1}`) | 否,avg 要 special case | +| 框架对 T 的内部结构 | 不感知 | 必须知道 field 数 | +| 新增复合 value 聚合(stddev/HLL/...) | 0 改动 | 框架每次都要扩展 | +| Memory layout | 24B(offset 0/8/16) | 24B(offset 0/8/16),完全相同 | +| LLVM lowering 性能 | 经 SROA/InstCombine 后与平铺一致 | 与嵌套一致 | +| IR 可读性 | 略冗长(多一层 `{0,k}`) | 更短 | + +**结论**:嵌套方案在抽象一致性、可扩展性上完胜,且无任何性能代价。**全部聚合统一采用嵌套布局。** + +--- + +## 7. 落地步骤(建议 PR 顺序) + +为了控制每个 MR 的 diff 体积,推荐拆 4 个 MR 提交: + +| # | MR 标题 | 范围 | 依赖 | +|---|---------|------|------| +| 1 | `[jit] Introduce IRRow + canonical Input/Output descriptors + GroupOps facade` | 在现有 `HashAggrJitOps` 外围引入三层抽象与单测,不改 chunk ABI | 无 | +| 2 | `[jit] Migrate sum/count/min/max onto GroupOps + Adapter internals` | 先迁简单标量聚合,外部入口保持兼容 | #1 | +| 3 | `[jit] Migrate avg with nested intermediate IRRow` | avg 落地嵌套方案,保留现有 state layout 与 extract 语义 | #2 | +| 4 | `[jit] Migrate decimal sum/avg and optionally shrink legacy tables` | decimal 收口,并视情况瘦身旧 descriptor / ops table | #3 | + +每个 MR 都要: +- 跑通现有 hash aggr e2e 测试集(重点覆盖含 null 输入的 case)。 +- 跑 micro benchmark 对比,必须 ≥ 当前 `9a65fd2` 性能。 +- LLVM IR dump(`-dump_ir`)肉眼检查 SROA 后是否消除了 alloca。 + +--- + +## 8. 
验收标准(Definition of Done) + +- [ ] `HashAggrJitDecodedInput / HashAggrJitOutput` 至少已收敛为 canonical descriptor;若仍保留,也不再继续按新聚合需求横向扩字段。 +- [ ] 任何 IRRow 字段访问只能经过 `IRRow::*` helper(grep `extractvalue.*IRRow|insertvalue.*IRRow` 应为零)。 +- [ ] JIT path 不再读写 `Aggregate::numNulls_`。 +- [ ] avg intermediate 的 LLVM type 等于 `{{double, i64}, i1}`(单测断言)。 +- [ ] 新增 stddev 或任意复合 value 聚合时,**不需要修改 InputAdapter/OutputAdapter/IRRow 等框架边界定义**;最多新增对应 `GroupOps` / state helper。 +- [ ] e2e 性能不回退,TPC-H Q1(avg 重灾区)持平或提升。 + +--- + +## 9. 备注 + +- 本方案与上游 Velox 无 conflict —— Velox 没有 hash aggr JIT,bolt 这块是独立分叉。 +- 如果未来引入 PartialFinal 优化、ROW vector 嵌套加深,IRRow 接口无需改动。 +- 复合 value 聚合超过 3 层嵌套(极少见)时,建议在 `IRRow` 上提供 path-based getter(`getValueByPath({0,1,2})`),但不在本次重构范围内。 From c027478af716861a21af2396fdee741862213c02 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Thu, 11 Jun 2026 23:46:41 +0800 Subject: [PATCH 48/98] update plan doc --- hashaggr_jit_refactor_plan.md | 765 ++++++++++++++++++++++++++++++++++ 1 file changed, 765 insertions(+) diff --git a/hashaggr_jit_refactor_plan.md b/hashaggr_jit_refactor_plan.md index 8e9cf86c8..86a202a70 100644 --- a/hashaggr_jit_refactor_plan.md +++ b/hashaggr_jit_refactor_plan.md @@ -567,6 +567,226 @@ velox/exec/jit/ ## 7. 落地步骤(建议 PR 顺序) +### 7.0 首个最小可实施 patch(本轮直接落地) + +为了避免第一步就同时改动 chunk ABI、aggregate ops table、descriptor 字段和 benchmark 口径,首个 patch 只做**最小且可验证**的框架落点: + +#### Patch-1 范围 + +1. 新增 `IRRow` helper(建议先放在现有 `bolt/jit/aggregation/` 目录下,而不是一开始新建整套 framework 目录): + - `getType` + - `pack` + - `getValue` + - `getIsNull` + - `withValue` + - `withIsNull` + - `getValueField` +2. 新增对应单测,至少覆盖: + - 标量 value 的 round-trip; + - 嵌套 value(如 `{{double, i64}, i1}`)的 round-trip; + - `withValue / withIsNull` 的覆盖更新语义; +3. **不修改**当前 `HashAggrJitChunk` ABI; +4. **不修改** `HashAggrJitDecodedInput / HashAggrJitOutput` 结构; +5. 
**不迁移**任何 aggregate 到新三层,只把 `IRRow` 作为第一块可复用基建先落进去。 + +#### Patch-1 预期收益 + +- 为后续 `GroupOps::loadIntermediate/finalize` 提供统一返回协议; +- 为 avg / decimal avg 这类复合 value 的嵌套 IRRow 建立稳定 helper; +- 先把最容易验证、最不影响性能的部分单独落地,降低后续 patch 风险。 + +#### Patch-1 验证方式 + +- 编译 `bolt_thrustjit`; +- 若当前配置包含测试,则额外编译并运行 `bolt_thrustjit_test`; +- 该 patch 不应改变任何现有 hash aggr JIT 生成 IR 的行为与性能。 + +#### Patch-2 范围(紧接 Patch-1) + +第二个最小 patch 继续保持“不改 ABI、不改 chunk 骨架”的原则,只做**标量输入读取的一层内部收口**: + +1. 新增一个极薄的 `DecodedScalarInputAdapter` helper; +2. 第一阶段只提供 `readKnownNotNull(...)`: + - 适用于当前外层控制流已经完成 top-level null 过滤后的路径; + - 直接把 `loadDecodedValue(...)` 的结果打包成 `IRRow`,并把 `is_null` 固定为 `false`; +3. 仅选择 `sum` 作为第一个接入对象,把 `SumOps.cpp` 中对标量输入的直接读取改成通过该 helper; +4. 不改 `HashAggrJitChunk`、`genAddDenseIR(...)`、`HashAggrJitDecodedInput` ABI; +5. `readNullable(...)`、`DecodedRowInputAdapter`、`GroupOps facade` 留到后续 patch。 + +#### Patch-2 预期收益 + +- 验证“InputAdapter 是内部 codegen helper,而不是按 batch encoding 分裂 ABI”的设计方向; +- 让后续 `sum/minmax/avg raw-input` 迁移时有统一入口; +- 继续保证热路径不回退:外层 null 分支与现有 tight loop 骨架保持不变。 + +#### Patch-2 验证方式 + +- 编译 `bolt_thrustjit_test`; +- 运行新增 `IRRow` / `DecodedScalarInputAdapter` 相关测试; +- 确认 `SumOps.cpp` 只是把“直接 load decoded scalar”替换成 helper,不改变现有 null 过滤与算子语义。 + +#### Patch-3 范围(延续 Patch-2) + +第三个最小 patch 继续沿用同一迁移策略,把 `DecodedScalarInputAdapter` 的使用从 `sum` 扩展到 `min/max`: + +1. 不新增 ABI; +2. 不修改 `DecodedScalarInputAdapter` 接口; +3. 仅把 `MinMaxOps.cpp` 中 raw-input 标量读取切换为 `DecodedScalarInputAdapter::readKnownNotNull(...)`; +4. 保持当前外层 null 过滤、NaN 处理和比较逻辑不变; +5. 
不触碰 merge row-field 路径,不引入 nullable adapter。 + +#### Patch-3 预期收益 + +- 让 `sum/min/max` 三个最基础的标量 raw-input 聚合统一走同一条内部读取入口; +- 进一步验证“InputAdapter 是内部 codegen helper,而不是新的运行时 ABI”; +- 为后续批量迁移其它标量聚合打样。 + +#### Patch-3 验证方式 + +- 编译 `bolt_thrustjit_test`; +- 运行现有 `IRRow` / `DecodedScalarInputAdapter` 相关测试,确认基建未破坏; +- 确认 `MinMaxOps.cpp` 只替换输入读取入口,不改变 min/max 的比较、NaN 语义和 extract 逻辑。 + +#### Patch-4 范围(先补 nullable contract,不立刻接算子) + +第四个最小 patch 只补 `DecodedScalarInputAdapter` 的 nullable contract,本轮**先落 helper 与单测,不接任何 aggregate**: + +1. 新增 `DecodedScalarInputAdapter::readNullable(...)`; +2. helper 负责: + - 读取 `nulls` 指针; + - 在 `nulls == nullptr` 时走非空快速路径; + - 在 `nulls != nullptr` 时按 row bit 判断是否为 null; + - 返回 `IRRow{value, is_null}`; +3. null 行上不要求读取真实 payload,允许写入 typed zero 作为占位值; +4. 本 patch **不修改** `sum/min/max/count/avg` 等 aggregate; +5. 通过一个可执行的 JIT 单测验证 nullable 语义,而不是只做类型级验证。 + +#### Patch-4 预期收益 + +- 正式建立 `DecodedScalarInputAdapter` 的 nullable 语义 contract; +- 为后续把外层 null 分支逐步内聚到 InputAdapter 提供基础; +- 先用单测把“null 行不读取真实 payload、仅传递 is_null”这件事定下来,避免后续改算子时语义摇摆。 + +#### Patch-4 验证方式 + +- 编译 `bolt_thrustjit_test`; +- 运行新增 `readNullable` JIT 语义测试: + - `nulls == nullptr` 时返回真实 value; + - `nulls` 位图标记该行为 null 时返回 `is_null = true`(bit 极性须与 bolt 的 nulls 位图约定一致:Velox 系约定通常是 bit = 0 表示 null,请以实际实现为准); + - 非 null 行仍返回真实 value; +- 继续跑已有 `IRRow` / `DecodedScalarInputAdapter` 基础测试,确认旧 contract 不回退。 + +#### Patch-5 范围(让 sum 消费 nullable IRRow,但先保留外层 null branch) + +第五个最小 patch 开始让真实 aggregate 消费 `readNullable(...)` 产出的 `IRRow`,但仍然坚持“不一次性收掉外层控制流”: + +1. `SumOps.cpp` 的 add/merge 路径统一先读取 `inputRow = DecodedScalarInputAdapter::readNullable(...)`; +2. `sum` 内部通过 `IRRow::getIsNull(inputRow)` 决定是否跳过累加; +3. 现有 `genAddDenseIR(...)` 的 top-level null 过滤分支**保留不动**; +4. 这意味着本 patch 的行为应与当前逻辑保持一致,只是把 `sum` 的内部消费协议收口到 nullable IRRow; +5. 
本 patch 不要求立即让外层 null 分支失效或删除。 + +#### Patch-5 预期收益 + +- 第一次验证“真实 aggregate 可以消费 nullable IRRow contract”; +- 为后续是否收掉外层 null 分支提供对照基线; +- 把 `sum` 变成第一个同时兼容 known-not-null 与 nullable 读取 contract 的算子样板。 + +#### Patch-5 验证方式 + +- 编译 `bolt_thrustjit_test`; +- 运行一个最小 `sum-like` JIT 语义测试:null 行返回旧 accumulator,不为 null 时返回 `old + value`; +- 继续运行已有 `IRRow` / `DecodedScalarInputAdapter` 基础测试,确认 helper contract 不回退; +- 确认 `SumOps.cpp` 仍未修改外层 null 过滤框架,仅改变输入消费方式。 + +#### Patch-6 范围(让 min/max 同样消费 nullable IRRow) + +第六个最小 patch 把 Patch-5 在 `sum` 上验证过的模式复制到 `min/max`: + +1. `MinMaxOps.cpp` 的 update 路径改为先读取 `inputRow = DecodedScalarInputAdapter::readNullable(...)`; +2. 通过 `IRRow::getIsNull(inputRow)` 显式跳过 null 行的比较与写回; +3. 非 null 行仍执行原有 min/max 比较、NaN 处理与 accumulator null 清除逻辑; +4. 现有 `genAddDenseIR(...)` 的 top-level null 过滤分支保留不动; +5. 行为与当前实现保持一致,仅把输入消费协议收口到 nullable IRRow。 + +#### Patch-6 预期收益 + +- 让 `sum/min/max` 统一以 nullable IRRow contract 消费输入; +- 进一步验证“先收口消费协议、暂不删外层 null branch”这一渐进模式在带比较/NaN 语义的算子上同样成立; +- 为后续真正收掉外层 null 分支留出一致的算子基线。 + +#### Patch-6 验证方式 + +- 编译 `bolt_thrustjit_test`; +- 运行已有 `IRRow` / `DecodedScalarInputAdapter` / `sum-like` 测试,确认 contract 不回退; +- 确认 `MinMaxOps.cpp` 仅替换输入消费方式,不改变比较、NaN 与 extract 语义,也未触碰外层 null 框架。 + +#### Patch-7 范围(引入最小 FlatOutputAdapter 并让 sum extract 接入) + +前面几个 patch 都在 input 端收口,本 patch 开始对称地在 output 端引入第一块 helper: + +1. 新增最小 `FlatOutputAdapter`(同样是 codegen-time helper,不引入任何运行时 ABI); +2. 只提供 `writeFromIRRow(codegen, output, row, slot, irRow)`: + - 从 `IRRow` 取 value 与 i1 `is_null`; + - 把 `is_null` zext 到 i8; + - 复用现有 `emitFlatValue(...)` 写回 flat 输出; +3. 让 `SumOps.cpp` 的 extract 先用 `IRRow::pack(value, is_null)` 组装,再通过 `FlatOutputAdapter::writeFromIRRow(...)` 写回; +4. 不修改 `HashAggrJitOutput` 结构与 `genExtractIR(...)` 骨架; +5. 
仅 `sum` 接入,`min/max/count/avg/decimal` 的 extract 暂不动。 + +#### Patch-7 预期收益 + +- 让 output 端也有一个与 `IRRow` 对齐的统一写回入口; +- 验证“OutputAdapter 也是内部 codegen helper,而非新 ABI”这一设计方向; +- 为后续把更多 extract 收口到 `FlatOutputAdapter` / `RowOutputAdapter` 打样。 + +#### Patch-7 验证方式 + +- 编译 `bolt_thrustjit_test`; +- 编译并运行 `bolt_aggregates_test` 的 `SumTest` 相关用例,确认 sum extract 行为未回归; +- 确认 `SumOps.cpp` 的 extract 仅改写写回入口,flat 输出语义与 null 位写入保持一致。 + +#### Patch-8 范围(min/max 与 count 的 extract 也接入 FlatOutputAdapter) + +继续把 output 端收口扩展到其余标量聚合: + +1. `MinMaxOps.cpp` 的 extract 改为:`IRRow::pack(value, isAccumulatorNull)` → `FlatOutputAdapter::writeFromIRRow(...)`; +2. `CountOps.cpp` 的 extract 改为:`IRRow::pack(value, false)` → `FlatOutputAdapter::writeFromIRRow(...)`(count 永不为 null); +3. 不修改 `HashAggrJitOutput` 结构与 `genExtractIR(...)` 骨架; +4. 行为与当前实现保持一致,只把写回入口统一到 `FlatOutputAdapter`。 + +#### Patch-8 预期收益 + +- 让 `sum/min/max/count` 四个标量聚合的 extract 全部走统一 output 入口; +- 进一步压实“OutputAdapter 是内部 codegen helper”的方向; +- 为后续 partial avg / decimal 等复杂 extract 的 RowOutputAdapter 收口铺路。 + +#### Patch-8 验证方式 + +- 编译 `bolt_thrustjit_test`、`bolt_aggregates_test`; +- 运行 `CountAggregationTest` 与 `SumTest` 相关用例,确认 extract 未引入新回归; +- 已知 `MinMaxTest` 三个 JIT 对照用例失败,本轮先忽略,仅确认未新增其它失败。 + +#### Patch-9 范围(avg 的 final extract 接入 FlatOutputAdapter) + +avg 的 final extract 本质就是把 `avg = sum / count` 写回一个 flat double,与 sum/min/max 同形态,因此先收口 final 分支: + +1. `AvgOps.cpp` 的 `compileAvgExtract` 在 `partialOutput == false` 分支改为:`IRRow::pack(avg, is_null)` → `FlatOutputAdapter::writeFromIRRow(...)`; +2. partial avg 的 ROW 输出(`emitPartialAvgResult`)暂不动,留待后续 `RowOutputAdapter`; +3. 不修改 `HashAggrJitOutput` 结构与 `genExtractIR(...)` 骨架; +4. 
final avg 的 `count == 0 -> null`、divide 语义保持不变。 + +#### Patch-9 预期收益 + +- 让 sum/min/max/count/avg(final) 的 flat extract 全部走统一 output 入口; +- 把 partial(ROW)与 final(flat)两类 output 路径的边界显式化,为 `RowOutputAdapter` 铺路。 + +#### Patch-9 验证方式 + +- 编译 `bolt_thrustjit_test`、`bolt_aggregates_test`; +- 运行 `AverageAggregationTest` 相关用例,确认 final avg extract 行为未回归; +- 确认 `AvgOps.cpp` 仅改写 final 分支写回入口,partial ROW 输出与 divide/null 语义不变。 + 为了控制每个 MR 的 diff 体积,推荐拆 4 个 MR 提交: | # | MR 标题 | 范围 | 依赖 | @@ -599,3 +819,548 @@ velox/exec/jit/ - 本方案与上游 Velox 无 conflict —— Velox 没有 hash aggr JIT,bolt 这块是独立分叉。 - 如果未来引入 PartialFinal 优化、ROW vector 嵌套加深,IRRow 接口无需改动。 - 复合 value 聚合超过 3 层嵌套(极少见)时,建议在 `IRRow` 上提供 path-based getter(`getValueByPath({0,1,2})`),但不在本次重构范围内。 + +--- + +## 10. InputAdapter 虚接口重构设计(Approach-1 落地稿) + +这一章收敛“最终想要的效果”——**真正建立 `InputAdapter -> GroupOps -> OutputAdapter` 三层正交骨架,并最终删除 `HashAggrJitDecodedInput`**,而不是继续在旧 descriptor 上横向打补丁。 + +### 10.1 目标边界 + +本轮 InputAdapter 重构必须同时满足: + +1. `InputAdapter` 提供**虚函数接口**; +2. adapter 在**构造时直接接受 vector 输入**,而不是接受 `HashAggrJitDecodedInput` 这类中间拼装物; +3. 第一层实现只分两类: + - `ScalarInputAdapter` + - `RowInputAdapter` +4. JIT IR **不能按 flat / constant / dictionary 三种 encoding 分叉生成**;输入 encoding 差异必须在 adapter 内部先被吸收、收敛; +5. 热路径性能不能回退:不能把“每行一次 helper / 每行一次虚调用”重新引回 add-dense loop; +6. 完成后可以删除 `HashAggrJitDecodedInput`,后续新增聚合也不允许再给它加字段。 + +### 10.2 分层职责 + +#### A. InputAdapter:只负责“把 vector 解释成 IRRow 输入契约” + +InputAdapter 的职责是: + +- 接受 batch 内真实的 `BaseVector` / `RowVector` 输入; +- 在 adapter 内部完成 decode / flatten / indices/nulls 收敛; +- 对 JIT 暴露**稳定、encoding 无关**的 runtime payload; +- 对 codegen 暴露“如何从该 payload 读出 `IRRow`”的统一接口。 + +它**不负责**: + +- 聚合 state layout; +- add / merge / extract 语义; +- 输出 vector 写回; +- 聚合专有字段语义(例如 avg 的 `sum/count`、decimal 的 `isEmpty`)。 + +#### B. 
GroupOps:只负责“消费/产生 IRRow” + +GroupOps 只看: + +- 输入:`IRRow` 或嵌套 `IRRow` +- 状态:`group + slot.offset` +- 输出:`IRRow` + +也就是说,`sum/min/max/count/avg/decimal` 的差异只体现在各自 ops/state helper 中; +**GroupOps 不拥有 InputAdapter / OutputAdapter 的 ABI,也不拥有 group layout 定义权**。 + +#### C. OutputAdapter:只负责“把 IRRow 写回结果 vector” + +OutputAdapter 只做两件事: + +- flat output:写一个标量 `IRRow`; +- row output:把 `IRRow` 的每个 child field 和 top-level null 写回。 + +它不应理解“这是 avg 的 2-field row”或“这是 decimal sum 的 3-field row”; +字段个数/字段类型来自 `IRRow` 的 payload type,本身不携带聚合语义。 + +### 10.3 运行时对象模型 + +最终运行时不再构造 `HashAggrJitDecodedInput`,而是构造 adapter 对象: + +```cpp +class InputAdapter { + public: + virtual ~InputAdapter() = default; + + // 供 batch 准备阶段调用;完成 decode / flatten / child adapter 建立。 + virtual void prepare() = 0; + + // 返回稳定的、可传入 JIT add_dense 的 runtime payload。 + virtual const void* runtime() const = 0; + + // 返回与该 adapter 对应的 codegen 节点。 + virtual const InputAdapterCodegen& codegen() const = 0; +}; + +class ScalarInputAdapter final : public InputAdapter { ... }; +class RowInputAdapter final : public InputAdapter { ... 
}; +``` + +这里的关键点是: + +- **虚函数只发生在 batch 准备阶段**; +- JIT 热循环不做 virtual dispatch; +- JIT add-dense 看到的仍然是 `char**` / `void**` 风格的 runtime payload 数组,只是 payload 的拥有者从旧 struct 变成了 adapter; +- 因此性能上仍保持“每行直接 load 指针/indices/nulls”的 fast path。 + +### 10.4 Runtime payload 形状 + +为了替换 `HashAggrJitDecodedInput`,需要把“adapter 对象”和“JIT 可直接 load 的 POD payload”解耦: + +```cpp +struct ScalarInputRuntime { + const void* values; + const int32_t* indices; + const uint64_t* nulls; +}; + +struct RowInputRuntime { + const uint64_t* nulls; + const ScalarInputRuntime* const* children; // scalar child runtimes + int32_t numChildren; +}; +``` + +约束如下: + +- `ScalarInputRuntime` 对应今天 canonical decoded descriptor 的标量子集; +- `RowInputRuntime` 不再内嵌 `rowField0Values/rowField1Values/...` 这种聚合专有字段; +- `RowInputRuntime` 也不再保留 `indices`:row 自身不再承载 dictionary/constant 等 wrapping,若 child 仍需索引映射,应下沉到 child 自己的 scalar runtime; +- 当前阶段 `RowInputRuntime.children[i]` 直接指向 `ScalarInputRuntime`;也就是说,本轮只支持 **row-of-scalars**,不引入递归 row child; +- 若某些 merge fast path 需要 child flat raw 指针,应该由 `ScalarInputAdapter` 自身保证其 runtime payload 已经是可直接 load 的 canonical scalar 形态,而不是再给顶层 row runtime 增加“field0/field1 特例字段”。 + +### 10.5 Codegen 侧接口 + +codegen 层不再把“输入读取”硬编码成 `loadDecodedValue / loadDecodedRowField*` 这一组围绕 `HashAggrJitDecodedInput` 的 offset 访问,而是收敛成: + +```cpp +class InputAdapterCodegen { + public: + virtual ~InputAdapterCodegen() = default; + // 返回 IRRow 中 payload 的 LLVM type。 + virtual llvm::Type* llvmValueType(HashAggrJitCodegen& codegen) const = 0; + + virtual llvm::Value* readIRRow( + HashAggrJitCodegen& codegen, + llvm::Value* runtime, + llvm::Value* row, + const HashAggrJitSlot& slot) const = 0; +}; +``` + +第一阶段只需要两个实现: + +- `ScalarInputAdapterCodegen` + - 从 `ScalarInputRuntime` 读出 `IRRow` + - 覆盖今天 `DecodedScalarInputAdapter::readKnownNotNull/readNullable` 的职责 +- `RowInputAdapterCodegen` + - 从 `RowInputRuntime` 先读出 `children[i]` 对应的 `ScalarInputRuntime*` + - 再通过内部持有的 scalar child readers 读取 child 的 value/null + 
- 最后组装顶层 `IRRow` + - 覆盖今天 `loadDecodedRowField / loadDecodedRowFieldBool` 这组“按 field 特判”的路径 + +这样 IR 收敛点就从“旧 descriptor 的固定字段”变成“adapter runtime 的稳定形状”。 + +#### 10.5.0.1 `RowInputRuntime.children` 如何读到 child 的 values/nulls + +这是 runtime / codegen 分层里最关键的一点: + +- `RowInputRuntime` **不直接暴露** `child0Values/child0Nulls/...` 这类字段; +- `RowInputRuntime` 只保存 `children[i]` —— 即 `ScalarInputRuntime*`; +- “如何从这个 child runtime 读出 value/null” 由 `RowInputAdapterCodegen` 内部持有的 scalar child readers 决定。 + +推荐读取流程: + +```cpp +auto* childRuntime = rowCodegen.loadScalarChildRuntime(codegen, rowRuntime, i); +auto* childRow = rowCodegen.scalarChildAt(i).readIRRow(codegen, childRuntime, row, slot); +auto* childValue = IRRow::getValue(builder, childRow); +auto* childIsNull = IRRow::getIsNull(builder, childRow); +``` + +也就是说: + +1. row runtime 只负责提供 `children[i]` 指针; +2. 当前阶段 child 固定为 scalar,因此不需要 runtime tag,也不需要递归 row dispatch; +3. child 的 `values/nulls/indices` 在 `ScalarInputRuntime` 内部读取; +4. 因而“读 children 的 values/nulls”不是 `RowInputRuntime` 的接口,而是 `RowInputAdapterCodegen` 调度其 scalar child readers 的结果。 + +这也解释了为什么 row runtime 本身不需要再带: + +- `indices` +- `rowField0Values` +- `rowField1Values` + +因为这些都属于 child 的读取策略,不属于 row root 的职责。 + +#### 10.5.0.2 `InputAdapter` / `InputAdapterCodegen` 应如何调整 + +为了让上面的 child-reading 成立,接口要从“单节点一次性读完”调整成“runtime root + codegen 节点协作”: + +##### 运行时对象层 + +`InputAdapter` 负责两件事: + +1. 构造并拥有 runtime payload; +2. 暴露与之匹配的 codegen 节点。 + +也就是说,`InputAdapter::runtime()` 和 `InputAdapter::codegen()` 必须成对出现。 + +##### codegen 层 + +`InputAdapterCodegen` 基类仍然只需要: + +1. `llvmValueType(...)`:告诉框架当前节点的 payload LLVM type; +2. `readIRRow(...)`:把当前 runtime 解释成 `IRRow`。 + +row-specific 的 child 访问辅助接口不必上提到基类;它们由 `RowInputAdapterCodegen` 自身私有持有即可。这样接口更贴近当前“row-of-scalars”的范围,避免过早泛化。 + +##### `RowInputAdapterCodegen` 的实现语义 + +`RowInputAdapterCodegen::readIRRow(...)` 应按下面语义生成 IR: + +1. 先检查 `rowRuntime.nulls`,得到 top-level row 是否为 null; +2. 
若 top-level 为 null,返回 `IRRow{zero_payload, true}`; +3. 若 top-level 非 null,则对每个 child: + - `childRuntime = loadScalarChildRuntime(..., i)` + - `childRow = scalarChildAt(i).readIRRow(..., childRuntime, row, slot)` +4. 组装 payload: + - 若当前业务只需要 child value(例如 avg merge 的 `sum/count` 都是非 null 标量),可把 `IRRow::getValue(childRow)` 填入 payload; + - 当前阶段只支持 scalar child;若未来要支持 nested row child,再重新把 row child codegen 抽象上提。 + +因此,**从 `children` 读 values/nulls 的正确模型,不是给 `RowInputRuntime` 增字段,而是让 `RowInputAdapterCodegen` 持有一组 scalar child readers,并逐个解释 `ScalarInputRuntime`。** + +#### 10.5.1 `genAddDenseIR` 的 LLVM function 接口如何设计 + +结论先说:**可以去掉 `HashAggrJitDecodedInput`,而且 `genAddDenseIR` 的 LLVM ABI 不需要大改;最好的做法是“保留 3 参函数形状,只替换第 3 个参数的语义”**。 + +推荐接口: + +```cpp +using HashAggrJitAddDenseFunc = + void (*)(char** groups, int32_t numRows, char** inputRuntimes); +``` + +对应 LLVM: + +```llvm +define void @jit_HashAggrAddDense( + i8** %groups, + i32 %num_rows, + i8** %input_runtimes) +``` + +也就是说: + +- 参数 1:`groups`,不变; +- 参数 2:`numRows`,不变; +- 参数 3:从今天的 `decodedInputs` 改成 **`inputRuntimes`**; +- `inputRuntimes[slotIndex]` 指向该 slot 对应 InputAdapter 持有的 root runtime payload; +- JIT 函数本身**不知道也不需要知道**这是 C++ 虚对象,只把它当成 adapter-owned POD runtime 根指针来读。 + +这样做的关键收益是: + +1. `HashAggrJitChunk`、ORC JIT function pointer、调用侧大框架都几乎不用改 ABI; +2. `GroupingSet` 只需把 `hashAggrJitDecodedPtrs_` 的元素从“指向 `HashAggrJitDecodedInput`”改成“指向 adapter runtime”; +3. 热循环仍然是 `slotIndex -> load runtime ptr -> 直接 load values/indices/nulls`,不会引入 per-row virtual dispatch。 + +#### 10.5.2 为什么不建议把 LLVM 接口改成 `InputAdapter**` + +不推荐这种形状: + +```cpp +void (*)(char** groups, int32_t numRows, InputAdapter** adapters) +``` + +原因: + +1. JIT 热路径若想通过 `InputAdapter**` 做 virtual call,会直接把虚调用引进每行循环; +2. LLVM 对 C++ vtable/object layout 没有必要也不应该感知; +3. 
我们真正需要的是“稳定可 load 的 runtime payload”,而不是对象本身。 + +所以正确分层应该是: + +- **C++ 对象层**:`InputAdapter` 虚接口,负责 batch 准备; +- **JIT ABI 层**:`i8** input_runtimes`,只传 POD payload 指针; +- **codegen 层**:由 slot 绑定的 adapter codegen helper 决定如何解释这个 payload。 + +#### 10.5.3 `genAddDenseIR` 内部如何按 slot 解释第 3 个参数 + +`genAddDenseIR(...)` 的 skeleton 推荐改成: + +```cpp +for each slot i: + runtime = load input_runtimes[i] + if (checkInputNulls && !countStar) { + if (slot.inputCodegen->topLevelIsNull(codegen, runtime, row)) { + goto next_slot; + } + } + addFn(codegen, group, runtime, row, slot, ...) +``` + +这里有两个重要点: + +1. **slot 用哪个 adapter codegen,是编译期常量,不是运行期分派**; +2. 第 3 参始终只是 `i8* runtime`,真正如何解释成 scalar/row runtime,由该 slot 对应的 codegen helper 完成。 + +也就是说,`HashAggrJitOps::AddFn` 仍然可以保持“每个聚合一个 add 函数”的结构,但其参数语义应从: + +```cpp +llvm::Value* decoded +``` + +改成: + +```cpp +llvm::Value* inputRuntime +``` + +然后在 `sum/min/max/avg/...` 的 addFn 内部统一写成: + +```cpp +auto* inputRow = slot.inputCodegen->readIRRow(codegen, inputRuntime, row, slot); +``` + +这样 GroupOps 看到的始终就是 `IRRow`,不再碰 `HashAggrJitDecodedInput` 的字段偏移。 + +#### 10.5.4 运行时 payload 推荐形状 + +推荐 root runtime 只保留两种: + +```cpp +struct ScalarInputRuntime { + const void* values; + const int32_t* indices; + const uint64_t* nulls; +}; + +struct RowInputRuntime { + const uint64_t* nulls; + const ScalarInputRuntime* const* children; +}; +``` + +这里刻意**不再放**: + +- `decodedVector` +- `rowField0Values` +- `rowField1Values` + +因为这些都是把框架重新绑回旧 descriptor / 特定聚合语义的回退路线。 + +avg merge / decimal sum merge 这类历史快路径,应该改为: + +- `RowInputRuntime.children[0]` 指向 field0 的 `ScalarInputRuntime` +- `RowInputRuntime.children[1]` 指向 field1 的 `ScalarInputRuntime` + +这样 JIT 仍然可以直接读 child flat raw values,并不会失去快路径。 + +#### 10.5.5 `genAddDenseIR` 的无 null 快路径怎么保留 + +这部分仍然建议保留今天的双函数模型: + +- `addDense`:会做 top-level null check +- `addDenseNoNull`:不做 top-level null check + +也就是 LLVM ABI 仍是同一个 3 参函数类型,只是生成两份实现。 + +变化点不在函数签名,而在 skeleton 里的 null 判断从: + +```cpp +loadDecodedNulls(decoded) +``` + 
+变成: + +```cpp +slot.inputCodegen->loadTopLevelNulls(runtime) +// or slot.inputCodegen->topLevelIsNull(...) +``` + +这样 scalar / row 输入都能复用同一套外层 skeleton,而不是把 null 逻辑重新散落到各个聚合实现里。 + +#### 10.5.6 对实现顺序的直接指导 + +因此真正落地时,`genAddDenseIR` 这条线建议按下面顺序改: + +1. 先把 `decodedInputs` 变量/注释/语义重命名为 `inputRuntimes`; +2. 把 `HashAggrJitOps::AddFn` 的 `decoded` 参数语义改成 `inputRuntime`; +3. 在 slot 上挂 compile-time 的 input codegen/helper 信息; +4. 把外层 null gating 改成走 adapter helper; +5. 最后再删 `HashAggrJitDecodedInput`、`offsetof(...)` 常量与 `loadDecoded*` 专名 API。 + +**所以答案是:能去掉,而且最合理的 `genAddDenseIR` 设计不是改成“传 adapter 对象”,而是保留 `void(i8**, i32, i8**)` 形状,把第三个参数升级成 adapter-owned runtime payload 数组。** + +#### 10.5.7 `HashAggrJitDecodedInput` 是否应该改成 union + +这个方向**是可行的,而且比“继续扩一个大 struct”更优**;在当前 bolt hash aggr JIT 这条路径里, +我现在进一步收敛为: + +- **顶层输入 runtime 可以直接用无 tag 的 union root** +- scalar / row 由 **codegen 时已知的 adapter 结构** 决定,而不是由 runtime node 自描述 +- 但 **row 的 child 当前进一步收紧为 scalar-only**,不再让 child 也走统一 union node + +推荐形状: + +```cpp +union HashAggrJitInputRuntime; + +struct HashAggrJitScalarInputRuntime { + const void* values; + const int32_t* indices; + const uint64_t* nulls; +}; + +struct HashAggrJitRowInputRuntime { + const uint64_t* nulls; + const HashAggrJitScalarInputRuntime* const* children; + int32_t numChildren; +}; + +union HashAggrJitInputRuntime { + HashAggrJitScalarInputRuntime scalar; + HashAggrJitRowInputRuntime row; +}; +``` + +##### 为什么 union 方向是对的 + +因为它解决了当前 `HashAggrJitDecodedInput` 最大的问题: + +1. **把 scalar / row 两类输入形状显式分开**,而不是塞进一个横向扩字段的大 struct; +2. `rowField0Values / rowField1Values` 这种“为了某个聚合临时开洞”的模式可以消失; +3. 第三个参数仍然可以是“runtime root 指针数组”,不影响 `genAddDenseIR` 的 3 参 ABI 形状; +4. InputAdapter 的职责能真正落到“从 union runtime 读出 `IRRow`”,而不是继续围绕旧 `DecodedInput` 的字段偏移打补丁。 + +##### 为什么这里可以省掉 shape/tag + +因为当前 add_dense 的生成方式决定了: + +1. 每个 slot 在 codegen 时已经知道输入是 scalar 还是 row; +2. 当前 row 的每个 child 固定为 scalar,child 的读取方式在 codegen 时也是已知的; +3. 
热路径不需要 runtime shape dispatch,只需要按已知形状直接 load 对应字段。 + +因此,对 bolt 这条 JIT 路线而言,runtime node 的职责就是“承载值指针/nulls/children(以及 scalar 自己的 indices)”, +而不是“再告诉 JIT 自己是什么类型”。 + +##### 无 tag union 的前提条件 + +无 tag union 成立的前提是: + +1. **不能在热路径做 runtime kind 分派**; +2. slot 必须绑定 compile-time 的 input codegen/helper; +3. row child 的访问路径必须来自已知的 adapter 结构,而不是依赖 runtime 自描述; +4. 若需要 debug/assert,应由构造阶段或非 hot 校验逻辑承担,而不是把 tag 常驻在 runtime node 上。 + +也就是说:**union 是 runtime 承载方式,shape 是 codegen 元信息,不必塞进 runtime node。** + +##### 与 `genAddDenseIR` 接口的关系 + +即便采用无 tag union,`genAddDenseIR` 的 LLVM 接口也**不需要**变成复杂签名;仍建议保持: + +```cpp +void (*)(char** groups, int32_t numRows, char** inputRuntimes) +``` + +或者在 C++ typedef 层写成更强语义版本: + +```cpp +using HashAggrJitAddDenseFunc = + void (*)(char** groups, + int32_t numRows, + HashAggrJitInputRuntime* const* inputRuntimes); +``` + +但 LLVM IR 里依然可以保持 `i8**`,避免 ABI 扩散。 + +##### 什么时候 union 比“分离 root struct + void*”更优 + +我现在更偏向这个简化后的 union 方案,前提是满足下面两点: + +1. 顶层输入 runtime 统一收敛到 scalar/row 两种 root 形状; +2. RowInputRuntime 的 child 固定为 scalar runtime,由 `RowInputAdapterCodegen` 的 scalar child readers 解释成 `IRRow`。 + +因为当前实际需求只覆盖 row-of-scalars,这比“外面全是 `void*`,每层都靠约定 cast”更稳,也比提前支持递归 row 更容易落地。 + +##### 什么时候 union 仍然不够好 + +如果只是把当前这个 struct 生硬改成: + +- 一个 scalar variant +- 一个仍然带 `rowField0/rowField1` 的 row variant + +那仍然不够好,因为这只是把“avg / decimal sum 的聚合语义”从 struct 平铺变成 union variant,**没有真正建立 generic row runtime**。 + +所以采用这一版 union 的最低要求是: + +- row variant 只能有 `nulls/children/numChildren` +- `children` 必须直接指向 `ScalarInputRuntime` +- 不能再出现 `field0/field1` 这种聚合专有字段名 + +##### 结论 + +因此,对“是否可将 `HashAggrJitDecodedInput` 改成 union,并将 union 指针作为 add_dense 第三个参数传入 LLVM function”这个问题,我的结论是: + +- **可以,而且方向是对的;** +- **比继续沿用一个大而全的 struct 更优;** +- **在当前 bolt JIT 路线里,runtime node 可以不带 shape/tag;** +- **并且当前阶段 row variant 应进一步限定为 scalar-children 形状,否则实现复杂度会明显超前于需求。** + +### 10.6 为什么不会导致性能回退 + +性能保护原则: + +1. **不在每行调用虚函数**:virtual dispatch 仅用于 batch 准备; +2. 
**不在每行调用通用 runtime helper**:常见标量输入仍展开成直接 load `values + indices[row]` / `nulls[row]`; +3. **保留 raw child fast path**:row merge 若 child 已是 flat canonical scalar runtime,codegen 直接读取 child runtime,不退回 `DecodedVector` helper; +4. **让 encoding 差异前置到 adapter 构造**:dictionary/constant/flat 的分歧在 adapter `prepare()` 内吸收,JIT IR 只面对 canonical runtime payload; +5. **外层 add-dense skeleton 不被打散**:只替换“单 slot 如何读 input”,不引入额外 per-row 框架判断。 + +换句话说,InputAdapter 的虚接口是**对象建模边界**,不是热循环执行模型。 + +### 10.7 与当前 patch 序列的衔接 + +当前已经落下的 `IRRow`、`DecodedScalarInputAdapter`、`FlatOutputAdapter` 可以视为最终架构的前置垫片: + +- `IRRow`:保留,作为三层之间唯一值契约; +- `DecodedScalarInputAdapter`:后续升格为 `ScalarInputAdapterCodegen`,不再依附旧 `HashAggrJitDecodedInput` 命名; +- `RowOutputAdapter`:必须保持 generic,只按 struct field 写回,不认 avg 的 2-field 语义; +- `HashAggrJitCodegen::loadDecoded*`:逐步收缩为 adapter runtime 读取 helper,最终删掉 decoded-input 专名 API。 + +### 10.8 建议迁移顺序 + +#### Phase A:先把 codegen 边界改对 + +1. 把当前 `RowOutputAdapter` 改成真正 generic 的 struct writer; +2. 在 `HashAggrJit.h/.cpp` 中引入 input runtime union / adapter codegen 概念; +3. 让 avg partial、decimal merge 等 row 输入/输出先走 generic row contract,而不是 field0/field1 语义 helper。 + +#### Phase B:引入 runtime InputAdapter 对象,但暂不改 add_dense ABI + +1. `GroupingSet` 内部改为构造 `ScalarInputAdapter` / `RowInputAdapter`; +2. adapter 自己持有 runtime payload; +3. 传给 JIT 的仍可先保持 `char** inputs`,但每个元素改为指向 adapter-owned runtime,而不是 `HashAggrJitDecodedInput`。 + +这一阶段完成后,`HashAggrJitDecodedInput` 已经可以从执行路径移除,只剩个别 helper / test 兼容点。 + +#### Phase C:删除旧 descriptor 与旧命名 helper + +1. 删除 `HashAggrJitDecodedInput`; +2. 删除 `loadDecodedValue/loadDecodedNulls/loadDecodedRowField*` 这组旧 API; +3. 测试与 benchmark 一律改用 adapter 构造路径; +4. 
清理 `offsetof(HashAggrJitDecodedInput, ...)` 常量与相关 runtime helper。 + +### 10.9 本章对应的 DoD 补充 + +完成 InputAdapter 重构后,应额外满足: + +- [ ] `GroupingSet` 不再直接构造 `HashAggrJitDecodedInput`; +- [ ] JIT add-dense ABI 传递的是 adapter-owned runtime payload; +- [ ] row merge / row extract 不再出现 `rowField0/rowField1` 这类聚合专有字段名; +- [ ] 新增一个 3-field intermediate 聚合时,不需要修改 InputAdapter/OutputAdapter 基类接口。 From fc440956b4bbb070031422f45964ff5023309b50 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 00:31:54 +0800 Subject: [PATCH 49/98] refactor input and output --- bolt/exec/GroupingSet.cpp | 165 ++-- bolt/exec/GroupingSet.h | 14 +- bolt/jit/aggregation/HashAggrJit.cpp | 866 ++++++++++----------- bolt/jit/aggregation/HashAggrJit.h | 207 ++++- bolt/jit/aggregation/HashAggrJitTypes.h | 85 +- bolt/jit/aggregation/ops/AvgOps.cpp | 37 +- bolt/jit/aggregation/ops/CountOps.cpp | 26 +- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 20 +- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 18 +- bolt/jit/aggregation/ops/MinMaxOps.cpp | 14 +- bolt/jit/aggregation/ops/SumOps.cpp | 14 +- 11 files changed, 835 insertions(+), 631 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 850a9bd2f..04ab0631b 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -116,9 +116,12 @@ const void* hashAggrJitRawInputValues( return nullptr; } -void fillHashAggrJitRowFieldInputs( - jit::HashAggrJitDecodedInput& input, - const DecodedVector& decoded, +bool fillHashAggrJitRowInputRuntime( + jit::HashAggrJitInputRuntime& input, + std::vector& children, + std::vector& childPtrs, + DecodedVector& decoded, + const SelectivityVector& rows, const jit::HashAggrJitSlot& slot) { // The raw row-field fast path applies to merge inputs whose intermediate // representation is ROW(field0, field1): avg's ROW(sum, count) and decimal @@ -127,55 +130,71 @@ void fillHashAggrJitRowFieldInputs( // directly instead of calling the per-row jit_GetDecodedRowField* helper // (which
rebuilds a field DecodedVector on every call). if (!slot.desc.mergeInput) { - return; + return false; } const bool isAvg = slot.desc.kind == jit::HashAggrJitKind::Avg; const bool isDecimalSum = slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal; if (!isAvg && !isDecimalSum) { - return; + return false; } const auto* base = decoded.base(); if (base == nullptr || base->encoding() != VectorEncoding::Simple::ROW) { - return; + return false; } const auto* rowVector = base->asUnchecked(); if (rowVector->childrenSize() < 2) { - return; + return false; } const auto& sumVector = rowVector->childAt(0); if (sumVector->encoding() != VectorEncoding::Simple::FLAT) { - return; + return false; } - input.rowField0Values = - hashAggrJitRawInputValues(sumVector.get(), slot.desc.inputKind); - input.rowField0Nulls = sumVector->rawNulls(); + children.resize(2); + childPtrs.resize(2); + children[0] = jit::HashAggrJitScalarInputRuntime{ + .values = hashAggrJitRawInputValues(sumVector.get(), slot.desc.inputKind), + .indices = decoded.indices(), + .nulls = sumVector->rawNulls()}; // field1 differs by aggregate: avg's count is a flat int64 scalar; decimal // sum's isEmpty is a bit-packed bool whose rawValues() is the bit-word - // buffer consumed by loadDecodedRowFieldBool's bit-read fast path. + // buffer consumed by the scalar bool bit-read fast path. 
const auto& field1Vector = rowVector->childAt(1); if (field1Vector->encoding() != VectorEncoding::Simple::FLAT) { - return; + return false; } if (isAvg) { - input.rowField1Values = hashAggrJitRawInputValues( - field1Vector.get(), jit::HashAggrJitValueKind::Int64); - input.rowField1Nulls = field1Vector->rawNulls(); + children[1] = jit::HashAggrJitScalarInputRuntime{ + .values = hashAggrJitRawInputValues( + field1Vector.get(), jit::HashAggrJitValueKind::Int64), + .indices = decoded.indices(), + .nulls = field1Vector->rawNulls()}; } else { // isEmpty is bit-packed bool: valuesAsVoid() exposes the underlying - // bit-word buffer (rawValues() throws for bool). loadDecodedRowFieldBool - // bit-reads it directly. - input.rowField1Values = field1Vector->valuesAsVoid(); - input.rowField1Nulls = field1Vector->rawNulls(); - } + // bit-word buffer (rawValues() throws for bool). RowInputAdapterCodegen + // bit-reads it directly via the scalar child runtime. + children[1] = jit::HashAggrJitScalarInputRuntime{ + .values = field1Vector->valuesAsVoid(), + .indices = decoded.indices(), + .nulls = field1Vector->rawNulls()}; + } + childPtrs[0] = &children[0]; + childPtrs[1] = &children[1]; + input.row = jit::HashAggrJitRowInputRuntime{ + .nulls = decoded.nulls(&rows), + .children = childPtrs.data(), + .numChildren = static_cast(children.size())}; + return true; } // Fills the raw flat sum/count field pointers for a partial avg ROW output. // Returns false when the ROW children are not both FLAT (e.g. dictionary/ // constant wrapped), in which case the JIT fast path is not applicable and the // caller must fall back to the non-JIT extract path. 
-bool fillHashAggrJitPartialAvgOutput( - jit::HashAggrJitOutput& output, +bool fillHashAggrJitRowOutputRuntime( + jit::HashAggrJitOutputRuntime& output, + std::vector& children, + std::vector& childPtrs, BaseVector* vector, const jit::HashAggrJitSlot& slot) { auto* rowVector = vector->asUnchecked(); @@ -188,15 +207,27 @@ bool fillHashAggrJitPartialAvgOutput( countVector->encoding() != VectorEncoding::Simple::FLAT) { return false; } - output.rowField0Values = slot.desc.decimal - ? static_cast( - sumVector->asUnchecked>()->mutableRawValues()) - : static_cast( - sumVector->asUnchecked>()->mutableRawValues()); - output.rowField0Nulls = sumVector->mutableRawNulls(); - output.rowField1Values = - countVector->asUnchecked>()->mutableRawValues(); - output.rowField1Nulls = countVector->mutableRawNulls(); + children.resize(2); + childPtrs.resize(2); + children[0] = jit::HashAggrJitScalarOutputRuntime{ + .values = slot.desc.decimal + ? static_cast( + sumVector->asUnchecked>()->mutableRawValues()) + : static_cast( + sumVector->asUnchecked>()->mutableRawValues()), + .nulls = sumVector->mutableRawNulls(), + .vector = sumVector.get()}; + children[1] = jit::HashAggrJitScalarOutputRuntime{ + .values = countVector->asUnchecked>()->mutableRawValues(), + .nulls = countVector->mutableRawNulls(), + .vector = countVector.get()}; + childPtrs[0] = &children[0]; + childPtrs[1] = &children[1]; + output.row = jit::HashAggrJitRowOutputRuntime{ + .nulls = vector->mutableRawNulls(), + .children = childPtrs.data(), + .numChildren = static_cast(children.size()), + .vector = vector}; return true; } @@ -1121,9 +1152,11 @@ void GroupingSet::runHashAggrJitChunks( const auto numSlots = chunk.slots().size(); hashAggrJitDecoded_.resize(numSlots); - hashAggrJitDecodedInputs_.resize(numSlots); + hashAggrJitInputRuntimes_.resize(numSlots); + hashAggrJitRowChildren_.resize(numSlots); + hashAggrJitRowChildPtrs_.resize(numSlots); hashAggrJitInputVectors_.assign(numSlots, nullptr); - 
hashAggrJitDecodedPtrs_.assign(numSlots, nullptr); + hashAggrJitInputRuntimePtrs_.assign(numSlots, nullptr); bool canRunChunk = true; std::string skipReason; @@ -1165,19 +1198,32 @@ void GroupingSet::runHashAggrJitChunks( } hashAggrJitInputVectors_[slotIndex] = arg; hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); - hashAggrJitDecodedInputs_[slotIndex] = jit::HashAggrJitDecodedInput{ - .values = hashAggrJitDecoded_[slotIndex].dataAsVoid(), - .indices = hashAggrJitDecoded_[slotIndex].indices(), - .nulls = hashAggrJitDecoded_[slotIndex].nulls(&activeRows_), - .decodedVector = &hashAggrJitDecoded_[slotIndex]}; - fillHashAggrJitRowFieldInputs( - hashAggrJitDecodedInputs_[slotIndex], - hashAggrJitDecoded_[slotIndex], - slot); + const bool usesRowInputRuntime = slot.desc.mergeInput && + (slot.desc.kind == jit::HashAggrJitKind::Avg || + (slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal)); + if (usesRowInputRuntime) { + if (!fillHashAggrJitRowInputRuntime( + hashAggrJitInputRuntimes_[slotIndex], + hashAggrJitRowChildren_[slotIndex], + hashAggrJitRowChildPtrs_[slotIndex], + hashAggrJitDecoded_[slotIndex], + activeRows_, + slot)) { + canRunChunk = false; + skipReason = "ROW input runtime requires flat scalar row children"; + break; + } + } else { + hashAggrJitInputRuntimes_[slotIndex].scalar = + jit::HashAggrJitScalarInputRuntime{ + .values = hashAggrJitDecoded_[slotIndex].dataAsVoid(), + .indices = hashAggrJitDecoded_[slotIndex].indices(), + .nulls = hashAggrJitDecoded_[slotIndex].nulls(&activeRows_)}; + } inputsMayHaveNulls = inputsMayHaveNulls || hashAggrJitDecoded_[slotIndex].mayHaveNulls(); - hashAggrJitDecodedPtrs_[slotIndex] = - reinterpret_cast(&hashAggrJitDecodedInputs_[slotIndex]); + hashAggrJitInputRuntimePtrs_[slotIndex] = + reinterpret_cast(&hashAggrJitInputRuntimes_[slotIndex]); } if (!canRunChunk) { @@ -1198,7 +1244,7 @@ void GroupingSet::runHashAggrJitChunks( chunk.addDense( groups, activeRows_.end(), - hashAggrJitDecodedPtrs_.data(), 
+ hashAggrJitInputRuntimePtrs_.data(), inputsMayHaveNulls); for (const auto& slot : chunk.slots()) { jitExecuted[slot.aggregateIndex] = 1; @@ -1227,7 +1273,9 @@ void GroupingSet::runHashAggrJitExtractChunks( continue; } const auto numSlots = chunk.slots().size(); - hashAggrJitOutputs_.assign(numSlots, jit::HashAggrJitOutput{}); + hashAggrJitOutputRuntimes_.assign(numSlots, jit::HashAggrJitOutputRuntime{}); + hashAggrJitRowOutputChildren_.resize(numSlots); + hashAggrJitRowOutputChildPtrs_.resize(numSlots); hashAggrJitResultPtrs_.assign(numSlots, nullptr); bool canRunChunk = true; std::string skipReason; @@ -1252,25 +1300,30 @@ void GroupingSet::runHashAggrJitExtractChunks( } // Prepare stable raw output pointers after resizing. The JIT extract // function still receives char** for ABI compatibility, but each element - // now points to HashAggrJitOutput rather than BaseVector directly. + // now points to HashAggrJitOutputRuntime rather than BaseVector directly. aggregateVector->resize(groups.size()); - hashAggrJitOutputs_[slotIndex].vector = aggregateVector.get(); if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { - hashAggrJitOutputs_[slotIndex].values = - hashAggrJitRawOutputValues(aggregateVector.get(), slot.desc.accumulatorKind); - hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); + hashAggrJitOutputRuntimes_[slotIndex].scalar = + jit::HashAggrJitScalarOutputRuntime{ + .values = hashAggrJitRawOutputValues( + aggregateVector.get(), slot.desc.accumulatorKind), + .nulls = aggregateVector->mutableRawNulls(), + .vector = aggregateVector.get()}; } else if (aggregateVector->encoding() == VectorEncoding::Simple::ROW && slot.desc.kind == jit::HashAggrJitKind::Avg) { - hashAggrJitOutputs_[slotIndex].nulls = aggregateVector->mutableRawNulls(); - if (!fillHashAggrJitPartialAvgOutput( - hashAggrJitOutputs_[slotIndex], aggregateVector.get(), slot)) { + if (!fillHashAggrJitRowOutputRuntime( + hashAggrJitOutputRuntimes_[slotIndex], + 
hashAggrJitRowOutputChildren_[slotIndex], + hashAggrJitRowOutputChildPtrs_[slotIndex], + aggregateVector.get(), + slot)) { canRunChunk = false; skipReason = "partial avg row fields are not flat"; break; } } hashAggrJitResultPtrs_[slotIndex] = - reinterpret_cast(&hashAggrJitOutputs_[slotIndex]); + reinterpret_cast(&hashAggrJitOutputRuntimes_[slotIndex]); } if (!canRunChunk) { VLOG(1) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 7c47782e7..3296a89f7 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -464,13 +464,21 @@ class GroupingSet { #ifdef ENABLE_BOLT_JIT std::vector hashAggrJitChunks_; std::vector hashAggrJitDecoded_; - std::vector hashAggrJitDecodedInputs_; + std::vector hashAggrJitInputRuntimes_; + std::vector> + hashAggrJitRowChildren_; + std::vector> + hashAggrJitRowChildPtrs_; // Keeps input vectors alive for the DecodedVector buffers referenced by // JIT during addDense. 
std::vector hashAggrJitInputVectors_; - std::vector hashAggrJitDecodedPtrs_; + std::vector hashAggrJitInputRuntimePtrs_; std::vector hashAggrJitNewGroups_; - std::vector hashAggrJitOutputs_; + std::vector hashAggrJitOutputRuntimes_; + std::vector> + hashAggrJitRowOutputChildren_; + std::vector> + hashAggrJitRowOutputChildPtrs_; std::vector hashAggrJitResultPtrs_; #endif diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 930390e23..5b2e45f0c 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -19,8 +19,12 @@ extern "C" { -using bytedance::bolt::jit::HashAggrJitDecodedInput; -using bytedance::bolt::jit::HashAggrJitOutput; +using bytedance::bolt::jit::HashAggrJitInputRuntime; +using bytedance::bolt::jit::HashAggrJitOutputRuntime; +using bytedance::bolt::jit::HashAggrJitRowInputRuntime; +using bytedance::bolt::jit::HashAggrJitRowOutputRuntime; +using bytedance::bolt::jit::HashAggrJitScalarInputRuntime; +using bytedance::bolt::jit::HashAggrJitScalarOutputRuntime; namespace { @@ -50,31 +54,29 @@ void logHashAggrJitFunctionIR( << ir; } -constexpr uint64_t kDecodedInputIndicesOffset = - offsetof(HashAggrJitDecodedInput, indices); -constexpr uint64_t kDecodedInputNullsOffset = - offsetof(HashAggrJitDecodedInput, nulls); -constexpr uint64_t kDecodedInputDecodedVectorOffset = - offsetof(HashAggrJitDecodedInput, decodedVector); -constexpr uint64_t kDecodedInputFirstRowFieldOffset = - offsetof(HashAggrJitDecodedInput, rowField0Values); -constexpr uint64_t kDecodedInputRowFieldNullsOffsetDelta = - offsetof(HashAggrJitDecodedInput, rowField0Nulls) - - offsetof(HashAggrJitDecodedInput, rowField0Values); -constexpr uint64_t kDecodedInputRowFieldStride = - offsetof(HashAggrJitDecodedInput, rowField1Values) - - offsetof(HashAggrJitDecodedInput, rowField0Values); - -constexpr uint64_t kOutputNullsOffset = offsetof(HashAggrJitOutput, nulls); -constexpr uint64_t kOutputVectorOffset = 
offsetof(HashAggrJitOutput, vector); -constexpr uint64_t kOutputFirstRowFieldOffset = - offsetof(HashAggrJitOutput, rowField0Values); -constexpr uint64_t kOutputRowFieldNullsOffsetDelta = - offsetof(HashAggrJitOutput, rowField0Nulls) - - offsetof(HashAggrJitOutput, rowField0Values); -constexpr uint64_t kOutputRowFieldStride = - offsetof(HashAggrJitOutput, rowField1Values) - - offsetof(HashAggrJitOutput, rowField0Values); +constexpr uint64_t kScalarInputValuesOffset = + offsetof(HashAggrJitScalarInputRuntime, values); +constexpr uint64_t kScalarInputIndicesOffset = + offsetof(HashAggrJitScalarInputRuntime, indices); +constexpr uint64_t kScalarInputNullsOffset = + offsetof(HashAggrJitScalarInputRuntime, nulls); +constexpr uint64_t kRowInputNullsOffset = + offsetof(HashAggrJitRowInputRuntime, nulls); +constexpr uint64_t kRowInputChildrenOffset = + offsetof(HashAggrJitRowInputRuntime, children); + +constexpr uint64_t kScalarOutputValuesOffset = + offsetof(HashAggrJitScalarOutputRuntime, values); +constexpr uint64_t kScalarOutputNullsOffset = + offsetof(HashAggrJitScalarOutputRuntime, nulls); +constexpr uint64_t kScalarOutputVectorOffset = + offsetof(HashAggrJitScalarOutputRuntime, vector); +constexpr uint64_t kRowOutputNullsOffset = + offsetof(HashAggrJitRowOutputRuntime, nulls); +constexpr uint64_t kRowOutputChildrenOffset = + offsetof(HashAggrJitRowOutputRuntime, children); +constexpr uint64_t kRowOutputVectorOffset = + offsetof(HashAggrJitRowOutputRuntime, vector); } // namespace @@ -200,47 +202,6 @@ llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { return builder.getInt64Ty(); } -std::string decodedValueFunction(HashAggrJitValueKind kind) { - switch (kind) { - case HashAggrJitValueKind::Bool: - return "jit_GetDecodedValueBool"; - case HashAggrJitValueKind::Int8: - return "jit_GetDecodedValueI8"; - case HashAggrJitValueKind::Int16: - return "jit_GetDecodedValueI16"; - case HashAggrJitValueKind::Int32: - return "jit_GetDecodedValueI32"; - 
case HashAggrJitValueKind::Int64: - return "jit_GetDecodedValueI64"; - case HashAggrJitValueKind::Int128: - return "jit_GetDecodedValueI128"; - case HashAggrJitValueKind::Float: - return "jit_GetDecodedValueFloat"; - case HashAggrJitValueKind::Double: - return "jit_GetDecodedValueDouble"; - } - return "jit_GetDecodedValueI64"; -} - -std::string decodedRowFieldFunction(HashAggrJitValueKind kind) { - switch (kind) { - case HashAggrJitValueKind::Bool: - case HashAggrJitValueKind::Int8: - return "jit_GetDecodedRowFieldI8"; - case HashAggrJitValueKind::Int64: - return "jit_GetDecodedRowFieldI64"; - case HashAggrJitValueKind::Int128: - return "jit_GetDecodedRowFieldI128"; - case HashAggrJitValueKind::Double: - return "jit_GetDecodedRowFieldDouble"; - case HashAggrJitValueKind::Int16: - case HashAggrJitValueKind::Int32: - case HashAggrJitValueKind::Float: - break; - } - return ""; -} - std::string setFlatValueFunction(HashAggrJitValueKind kind); bool isFloatKind(HashAggrJitValueKind kind) { @@ -264,27 +225,38 @@ bool supportsRawFlatOutput(HashAggrJitValueKind kind) { return false; } -llvm::Value* loadOutputValues(llvm::IRBuilder<>& builder, llvm::Value* output) { +llvm::Value* loadPointerField( + llvm::IRBuilder<>& builder, + llvm::Value* descriptor, + uint64_t offset, + llvm::Type* pointerType, + llvm::StringRef name); + +llvm::Value* loadScalarOutputValues( + llvm::IRBuilder<>& builder, + llvm::Value* output) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* valuesPtrPtr = builder.CreatePointerCast(output, i8PtrTy->getPointerTo()); - return builder.CreateLoad(i8PtrTy, valuesPtrPtr, "output_values"); + return loadPointerField( + builder, output, kScalarOutputValuesOffset, i8PtrTy, "output_values"); } -llvm::Value* loadOutputNulls(llvm::IRBuilder<>& builder, llvm::Value* output) { - auto* i64Ty = builder.getInt64Ty(); - auto* nullsAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), output, kOutputNullsOffset); - auto* nullsPtrPtr = - 
builder.CreatePointerCast(nullsAddr, i64Ty->getPointerTo()->getPointerTo()); - return builder.CreateLoad(i64Ty->getPointerTo(), nullsPtrPtr, "output_nulls"); +llvm::Value* loadScalarOutputNulls( + llvm::IRBuilder<>& builder, + llvm::Value* output) { + return loadPointerField( + builder, + output, + kScalarOutputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "output_nulls"); } -llvm::Value* loadOutputVector(llvm::IRBuilder<>& builder, llvm::Value* output) { +llvm::Value* loadScalarOutputVector( + llvm::IRBuilder<>& builder, + llvm::Value* output) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* vectorAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), output, kOutputVectorOffset); - auto* vectorPtrPtr = builder.CreatePointerCast(vectorAddr, i8PtrTy->getPointerTo()); - return builder.CreateLoad(i8PtrTy, vectorPtrPtr, "output_vector"); + return loadPointerField( + builder, output, kScalarOutputVectorOffset, i8PtrTy, "output_vector"); } llvm::Value* loadPointerField( @@ -299,54 +271,132 @@ llvm::Value* loadPointerField( return builder.CreateLoad(pointerType, fieldPtrPtr, name); } -llvm::Value* loadDecodedIndex( +llvm::Value* loadScalarInputIndex( llvm::IRBuilder<>& builder, - llvm::Value* decoded, + llvm::Value* input, llvm::Value* row) { auto* i32Ty = builder.getInt32Ty(); auto* indices = loadPointerField( builder, - decoded, - kDecodedInputIndicesOffset, + input, + kScalarInputIndicesOffset, i32Ty->getPointerTo(), - "decoded_indices"); + "input_indices"); return builder.CreateLoad(i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); } -llvm::Value* loadDecodedRowFieldPointer( +llvm::Value* loadScalarInputValues( llvm::IRBuilder<>& builder, - llvm::Value* decoded, - int32_t field, - bool nulls) { + llvm::Value* input) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* pointerType = nulls ? 
builder.getInt64Ty()->getPointerTo() : i8PtrTy; - auto offset = kDecodedInputFirstRowFieldOffset + - static_cast(field) * kDecodedInputRowFieldStride + - (nulls ? kDecodedInputRowFieldNullsOffsetDelta : 0); + return loadPointerField( + builder, input, kScalarInputValuesOffset, i8PtrTy, "input_values"); +} + +llvm::Value* loadScalarInputNulls( + llvm::IRBuilder<>& builder, + llvm::Value* input) { return loadPointerField( builder, - decoded, - offset, - pointerType, - nulls ? "decoded_row_field_nulls" : "decoded_row_field_values"); + input, + kScalarInputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "input_nulls"); } -llvm::Value* loadOutputRowFieldPointer( +llvm::Value* loadRowInputNulls(llvm::IRBuilder<>& builder, llvm::Value* input) { + return loadPointerField( + builder, + input, + kRowInputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "row_input_nulls"); +} + +llvm::Value* loadRowInputChild( llvm::IRBuilder<>& builder, - llvm::Value* output, - int32_t field, - bool nulls) { + llvm::Value* input, + int32_t field) { auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* pointerType = nulls ? builder.getInt64Ty()->getPointerTo() : i8PtrTy; - auto offset = kOutputFirstRowFieldOffset + - static_cast(field) * kOutputRowFieldStride + - (nulls ? 
kOutputRowFieldNullsOffsetDelta : 0); + auto* children = loadPointerField( + builder, + input, + kRowInputChildrenOffset, + i8PtrTy->getPointerTo(), + "row_input_children"); + auto* childAddr = + builder.CreateConstInBoundsGEP1_64(i8PtrTy, children, field); + return builder.CreateLoad(i8PtrTy, childAddr, "row_input_child"); +} + +llvm::Value* loadScalarInputValue( + llvm::IRBuilder<>& builder, + llvm::Value* input, + llvm::Value* row, + HashAggrJitValueKind kind) { + auto* values = loadScalarInputValues(builder, input); + auto* index = loadScalarInputIndex(builder, input, row); + + if (kind == HashAggrJitValueKind::Bool) { + auto* wordTy = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(index, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(index, builder.getInt32(63)); + auto* words = builder.CreatePointerCast(values, wordTy->getPointerTo()); + auto* word = builder.CreateLoad( + wordTy, + builder.CreateInBoundsGEP( + wordTy, + words, + builder.CreateZExt(wordIndex, builder.getInt64Ty()))); + auto* shifted = + builder.CreateLShr(word, builder.CreateZExt(bitIndex, wordTy)); + return builder.CreateZExt( + builder.CreateICmpNE( + builder.CreateAnd(shifted, builder.getInt64(1)), + builder.getInt64(0)), + builder.getInt8Ty()); + } + + auto* type = llvmType(builder, kind); + auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); + auto* valueAddr = builder.CreateInBoundsGEP( + type, typedValues, builder.CreateZExt(index, builder.getInt64Ty())); + auto* load = builder.CreateLoad(type, valueAddr); + load->setAlignment(llvm::Align(1)); + return load; +} + +llvm::Value* loadRowOutputNulls(llvm::IRBuilder<>& builder, llvm::Value* output) { return loadPointerField( builder, output, - offset, - pointerType, - nulls ? 
"output_row_field_nulls" : "output_row_field_values"); + kRowOutputNullsOffset, + builder.getInt64Ty()->getPointerTo(), + "row_output_nulls"); +} + +llvm::Value* loadRowOutputVector( + llvm::IRBuilder<>& builder, + llvm::Value* output) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + return loadPointerField( + builder, output, kRowOutputVectorOffset, i8PtrTy, "row_output_vector"); +} + +llvm::Value* loadRowOutputChild( + llvm::IRBuilder<>& builder, + llvm::Value* output, + int32_t field) { + auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); + auto* children = loadPointerField( + builder, + output, + kRowOutputChildrenOffset, + i8PtrTy->getPointerTo(), + "row_output_children"); + auto* childAddr = + builder.CreateConstInBoundsGEP1_64(i8PtrTy, children, field); + return builder.CreateLoad(i8PtrTy, childAddr, "row_output_child"); } void emitOutputNullBit( @@ -364,7 +414,9 @@ void emitOutputNullBit( builder.getInt64(1), builder.CreateZExt(bitIndex, builder.getInt64Ty())); auto* notNullWord = builder.CreateOr(word, mask); auto* nullWord = builder.CreateAnd(word, builder.CreateNot(mask)); - auto* isNullBool = builder.CreateICmpNE(isNull, builder.getInt8(0)); + auto* isNullBool = isNull->getType()->isIntegerTy(1) + ? 
isNull + : builder.CreateICmpNE(isNull, builder.getInt8(0)); builder.CreateStore( builder.CreateSelect(isNullBool, nullWord, notNullWord), wordAddr); } @@ -455,62 +507,8 @@ void setAccumulatorNull( builder.CreateOr(byte, mask)); } -llvm::Value* loadDecodedValue( - llvm::IRBuilder<>& builder, - llvm::Value* decoded, - llvm::Value* row, - const HashAggrJitSlot& slot) { - auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* i32Ty = builder.getInt32Ty(); - - auto* valuesPtrPtr = builder.CreatePointerCast(decoded, i8PtrTy->getPointerTo()); - auto* values = builder.CreateLoad(i8PtrTy, valuesPtrPtr, "decoded_values"); - - auto* indicesAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), decoded, kDecodedInputIndicesOffset); - auto* indicesPtrPtr = - builder.CreatePointerCast(indicesAddr, i32Ty->getPointerTo()->getPointerTo()); - auto* indices = builder.CreateLoad(i32Ty->getPointerTo(), indicesPtrPtr, "decoded_indices"); - auto* index = builder.CreateLoad( - i32Ty, builder.CreateInBoundsGEP(i32Ty, indices, row)); - - if (slot.desc.inputKind == HashAggrJitValueKind::Bool) { - auto* wordTy = builder.getInt64Ty(); - auto* wordIndex = builder.CreateLShr(index, builder.getInt32(6)); - auto* bitIndex = builder.CreateAnd(index, builder.getInt32(63)); - auto* words = builder.CreatePointerCast(values, wordTy->getPointerTo()); - auto* word = builder.CreateLoad( - wordTy, - builder.CreateInBoundsGEP( - wordTy, words, builder.CreateZExt(wordIndex, builder.getInt64Ty()))); - auto* shifted = builder.CreateLShr(word, builder.CreateZExt(bitIndex, wordTy)); - return builder.CreateZExt( - builder.CreateICmpNE( - builder.CreateAnd(shifted, builder.getInt64(1)), builder.getInt64(0)), - builder.getInt8Ty()); - } - - auto* type = llvmType(builder, slot.desc.inputKind); - auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); - auto* valueAddr = builder.CreateInBoundsGEP( - type, typedValues, builder.CreateZExt(index, builder.getInt64Ty())); 
- auto* load = builder.CreateLoad(type, valueAddr); - load->setAlignment(llvm::Align(1)); - return load; -} - -llvm::Value* loadDecodedNulls(llvm::IRBuilder<>& builder, llvm::Value* decoded) { - auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* nullsAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), decoded, kDecodedInputNullsOffset); - auto* nullsPtrPtr = builder.CreatePointerCast(nullsAddr, i8PtrTy->getPointerTo()); - return builder.CreateLoad(i8PtrTy, nullsPtrPtr, "decoded_nulls"); -} - -llvm::Value* isDecodedNull( - llvm::IRBuilder<>& builder, - llvm::Value* nulls, - llvm::Value* row) { +llvm::Value* +isInputNull(llvm::IRBuilder<>& builder, llvm::Value* nulls, llvm::Value* row) { auto* i64Ty = builder.getInt64Ty(); auto* nullWords = builder.CreatePointerCast(nulls, i64Ty->getPointerTo()); auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); @@ -524,15 +522,6 @@ llvm::Value* isDecodedNull( builder.CreateAnd(shifted, builder.getInt64(1)), builder.getInt64(0)); } -llvm::Value* loadDecodedVector(llvm::IRBuilder<>& builder, llvm::Value* decoded) { - auto* i8PtrTy = llvm::PointerType::get(builder.getContext(), 0); - auto* decodedVectorAddr = builder.CreateConstInBoundsGEP1_64( - builder.getInt8Ty(), decoded, kDecodedInputDecodedVectorOffset); - auto* decodedVectorPtrPtr = - builder.CreatePointerCast(decodedVectorAddr, i8PtrTy->getPointerTo()); - return builder.CreateLoad(i8PtrTy, decodedVectorPtrPtr, "decoded_vector"); -} - } // namespace HashAggrJitCodegen::HashAggrJitCodegen(llvm::Module& module) : module_(module) { @@ -543,21 +532,10 @@ llvm::Type* HashAggrJitCodegen::llvmType(HashAggrJitValueKind kind) const { return ::bytedance::bolt::jit::llvmType(builder(), kind); } -llvm::Value* HashAggrJitCodegen::loadDecodedValue( - llvm::Value* decoded, - llvm::Value* row, - const HashAggrJitSlot& slot) const { - return ::bytedance::bolt::jit::loadDecodedValue(builder(), decoded, row, slot); -} - -llvm::Value* 
HashAggrJitCodegen::loadDecodedNulls(llvm::Value* decoded) const { - return ::bytedance::bolt::jit::loadDecodedNulls(builder(), decoded); -} - -llvm::Value* HashAggrJitCodegen::isDecodedNull( +llvm::Value* HashAggrJitCodegen::isInputNull( llvm::Value* nulls, llvm::Value* row) const { - return ::bytedance::bolt::jit::isDecodedNull(builder(), nulls, row); + return ::bytedance::bolt::jit::isInputNull(builder(), nulls, row); } llvm::Value* HashAggrJitCodegen::isAccumulatorNull( @@ -604,218 +582,152 @@ bool HashAggrJitCodegen::isFloatKind(HashAggrJitValueKind kind) const { return ::bytedance::bolt::jit::isFloatKind(kind); } -llvm::Value* HashAggrJitCodegen::loadDecodedRowField( - llvm::Value* decoded, +ScalarInputAdapterCodegen::ScalarInputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* input) + : codegen_(codegen), input_(input) {} + +llvm::StructType* ScalarInputAdapterCodegen::irRowType( + HashAggrJitValueKind kind) const { + return IRRow::getType(codegen_.builder(), codegen_.llvmType(kind)); +} + +llvm::Value* ScalarInputAdapterCodegen::read( llvm::Value* row, - int32_t field, HashAggrJitValueKind kind) const { - if (field == 0 || field == 1) { - auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( - builder(), decoded, field, false); - auto* hasRawValues = builder().CreateICmpNE( - rawValues, - llvm::ConstantPointerNull::get( - llvm::PointerType::get(builder().getContext(), 0))); - auto* fastBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_raw_load", - builder().GetInsertBlock()->getParent()); - auto* slowBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_helper_load", - builder().GetInsertBlock()->getParent()); - auto* doneBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_load_done", - builder().GetInsertBlock()->getParent()); - builder().CreateCondBr(hasRawValues, fastBlock, slowBlock); - - builder().SetInsertPoint(fastBlock); - auto* index = 
::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); - auto* type = llvmType(kind); - auto* typedValues = builder().CreatePointerCast(rawValues, type->getPointerTo()); - auto* valueAddr = builder().CreateInBoundsGEP( - type, typedValues, builder().CreateZExt(index, builder().getInt64Ty())); - auto* fastValue = builder().CreateLoad(type, valueAddr); - fastValue->setAlignment(llvm::Align(1)); - builder().CreateBr(doneBlock); - auto* fastEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(slowBlock); - const auto name = decodedRowFieldFunction(kind); - BOLT_CHECK( - !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); - auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); - auto* slowValue = builder().CreateCall( - module_.getFunction(name), - {decodedVector, row, builder().getInt32(field)}); - builder().CreateBr(doneBlock); - auto* slowEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(doneBlock); - auto* value = builder().CreatePHI(llvmType(kind), 2, "row_field_value"); - value->addIncoming(fastValue, fastEnd); - value->addIncoming(slowValue, slowEnd); - return value; - } - const auto name = decodedRowFieldFunction(kind); - BOLT_CHECK( - !name.empty(), "Unsupported decoded row field kind for HashAggrJit"); - auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); - return builder().CreateCall( - module_.getFunction(name), - {decodedVector, row, builder().getInt32(field)}); + auto* value = ::bytedance::bolt::jit::loadScalarInputValue( + codegen_.builder(), input_, row, kind); + // add_dense emits the top-level null guard before invoking aggregate ops. + // Therefore rows reaching ops are non-null; keep the IRRow contract explicit + // without duplicating the null bitmap check in every aggregate. 
+ return IRRow::pack(codegen_.builder(), value, codegen_.builder().getFalse()); +} + +llvm::Value* ScalarInputAdapterCodegen::loadNulls() const { + return ::bytedance::bolt::jit::loadScalarInputNulls( + codegen_.builder(), input_); } -llvm::Value* HashAggrJitCodegen::isDecodedRowFieldNull( - llvm::Value* decoded, +llvm::Value* ScalarInputAdapterCodegen::isNull(llvm::Value* row) const { + return codegen_.isInputNull(loadNulls(), row); +} + +llvm::Value* ScalarInputAdapterCodegen::readRowField( + llvm::Value*, + int32_t, + HashAggrJitValueKind) const { + BOLT_UNSUPPORTED("ScalarInputAdapterCodegen does not support ROW field load"); +} + +RowInputAdapterCodegen::RowInputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* input) + : codegen_(codegen), input_(input) {} + +llvm::Value* RowInputAdapterCodegen::loadChild(int32_t field) const { + return ::bytedance::bolt::jit::loadRowInputChild( + codegen_.builder(), input_, field); +} + +llvm::StructType* RowInputAdapterCodegen::irRowType( + HashAggrJitValueKind kind) const { + return IRRow::getType(codegen_.builder(), codegen_.llvmType(kind)); +} + +llvm::Value* RowInputAdapterCodegen::read(llvm::Value*, HashAggrJitValueKind) + const { + BOLT_UNSUPPORTED("RowInputAdapterCodegen does not support scalar loadValue"); +} + +llvm::Value* RowInputAdapterCodegen::loadNulls() const { + return ::bytedance::bolt::jit::loadRowInputNulls(codegen_.builder(), input_); +} + +llvm::Value* RowInputAdapterCodegen::isNull(llvm::Value* row) const { + return codegen_.isInputNull(loadNulls(), row); +} + +llvm::Value* RowInputAdapterCodegen::readRowField( llvm::Value* row, - int32_t field) const { - if (field == 0 || field == 1) { - auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( - builder(), decoded, field, false); - auto* rawNulls = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( - builder(), decoded, field, true); - auto* hasRawValues = builder().CreateICmpNE( - rawValues, - llvm::ConstantPointerNull::get( 
- llvm::PointerType::get(builder().getContext(), 0))); - auto* rawPathBlock = llvm::BasicBlock::Create( - module_.getContext(), "row_field_raw_null_path", builder().GetInsertBlock()->getParent()); - auto* helperPathBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_helper_null_path", - builder().GetInsertBlock()->getParent()); - auto* doneBlock = llvm::BasicBlock::Create( - module_.getContext(), "row_field_null_done", builder().GetInsertBlock()->getParent()); - builder().CreateCondBr(hasRawValues, rawPathBlock, helperPathBlock); - - builder().SetInsertPoint(rawPathBlock); - auto* hasRawNulls = builder().CreateICmpNE( - rawNulls, llvm::ConstantPointerNull::get(builder().getInt64Ty()->getPointerTo())); - auto* nullCheckBlock = llvm::BasicBlock::Create( - module_.getContext(), "row_field_null_check", builder().GetInsertBlock()->getParent()); - auto* rawDoneBlock = llvm::BasicBlock::Create( - module_.getContext(), "row_field_raw_null_done", builder().GetInsertBlock()->getParent()); - builder().CreateCondBr(hasRawNulls, nullCheckBlock, rawDoneBlock); - auto* noNullsEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(nullCheckBlock); - auto* index = ::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); - auto* isNull = ::bytedance::bolt::jit::isDecodedNull(builder(), rawNulls, index); - builder().CreateBr(rawDoneBlock); - auto* nullCheckEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(rawDoneBlock); - auto* fastNull = builder().CreatePHI(builder().getInt1Ty(), 2, "row_field_raw_is_null"); - fastNull->addIncoming(builder().getFalse(), noNullsEnd); - fastNull->addIncoming(isNull, nullCheckEnd); - builder().CreateBr(doneBlock); - auto* rawEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(helperPathBlock); - auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); - auto* helperNull = builder().CreateICmpNE( - builder().CreateCall( - 
module_.getFunction("jit_GetDecodedRowFieldIsNull"), - {decodedVector, row, builder().getInt32(field)}), - builder().getInt8(0)); - builder().CreateBr(doneBlock); - auto* helperEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(doneBlock); - auto* result = - builder().CreatePHI(builder().getInt1Ty(), 2, "row_field_is_null"); - result->addIncoming(fastNull, rawEnd); - result->addIncoming(helperNull, helperEnd); - return result; - } - auto* decodedVector = ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); - return builder().CreateICmpNE( - builder().CreateCall( - module_.getFunction("jit_GetDecodedRowFieldIsNull"), - {decodedVector, row, builder().getInt32(field)}), - builder().getInt8(0)); + int32_t field, + HashAggrJitValueKind kind) const { + auto* child = loadChild(field); + auto* value = ::bytedance::bolt::jit::loadScalarInputValue( + codegen_.builder(), child, row, kind); + return IRRow::pack(codegen_.builder(), value, isRowFieldNull(row, field)); } -llvm::Value* HashAggrJitCodegen::loadDecodedRowFieldBool( - llvm::Value* decoded, +llvm::Value* RowInputAdapterCodegen::isRowFieldNull( llvm::Value* row, int32_t field) const { - // Bool ROW fields are bit-packed, so the raw values pointer addresses an - // i64 word array indexed by bit, not a byte-per-value buffer. When the raw - // pointer is populated (merge fast path), read the bit directly; otherwise - // fall back to the helper which decodes the field per row. 
- auto* rawValues = ::bytedance::bolt::jit::loadDecodedRowFieldPointer( - builder(), decoded, field, false); - auto* hasRawValues = builder().CreateICmpNE( - rawValues, + auto* child = loadChild(field); + auto* nulls = + ::bytedance::bolt::jit::loadScalarInputNulls(codegen_.builder(), child); + auto* hasNulls = codegen_.builder().CreateICmpNE( + nulls, llvm::ConstantPointerNull::get( - llvm::PointerType::get(builder().getContext(), 0))); - auto* fastBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_bool_raw_load", - builder().GetInsertBlock()->getParent()); - auto* slowBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_bool_helper_load", - builder().GetInsertBlock()->getParent()); + codegen_.builder().getInt64Ty()->getPointerTo())); + auto* function = codegen_.builder().GetInsertBlock()->getParent(); + auto* nullCheckBlock = llvm::BasicBlock::Create( + codegen_.module().getContext(), "row_field_null_check", function); auto* doneBlock = llvm::BasicBlock::Create( - module_.getContext(), - "row_field_bool_load_done", - builder().GetInsertBlock()->getParent()); - builder().CreateCondBr(hasRawValues, fastBlock, slowBlock); - - builder().SetInsertPoint(fastBlock); - auto* index = ::bytedance::bolt::jit::loadDecodedIndex(builder(), decoded, row); - // bit at 'index' inside the i64 word array: word = index >> 6, bit = index & - // 63; value = (words[word] >> bit) & 1. 
- auto* i64Ty = builder().getInt64Ty(); - auto* words = builder().CreatePointerCast(rawValues, i64Ty->getPointerTo()); - auto* index64 = builder().CreateZExt(index, i64Ty); - auto* wordIndex = builder().CreateLShr(index64, builder().getInt64(6)); - auto* bitIndex = builder().CreateAnd(index64, builder().getInt64(63)); - auto* word = builder().CreateLoad( - i64Ty, builder().CreateInBoundsGEP(i64Ty, words, wordIndex)); - auto* shifted = builder().CreateLShr(word, bitIndex); - auto* bit = builder().CreateAnd(shifted, builder().getInt64(1)); - auto* fastValue = builder().CreateTrunc(bit, builder().getInt8Ty()); - builder().CreateBr(doneBlock); - auto* fastEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(slowBlock); - auto* decodedVector = - ::bytedance::bolt::jit::loadDecodedVector(builder(), decoded); - auto* slowValue = builder().CreateCall( - module_.getFunction("jit_GetDecodedRowFieldI8"), - {decodedVector, row, builder().getInt32(field)}); - builder().CreateBr(doneBlock); - auto* slowEnd = builder().GetInsertBlock(); - - builder().SetInsertPoint(doneBlock); - auto* value = - builder().CreatePHI(builder().getInt8Ty(), 2, "row_field_bool_value"); - value->addIncoming(fastValue, fastEnd); - value->addIncoming(slowValue, slowEnd); - return value; -} - -void HashAggrJitCodegen::emitFlatValue( - llvm::Value* output, + codegen_.module().getContext(), "row_field_null_done", function); + codegen_.builder().CreateCondBr(hasNulls, nullCheckBlock, doneBlock); + auto* noNullsEnd = codegen_.builder().GetInsertBlock(); + + codegen_.builder().SetInsertPoint(nullCheckBlock); + auto* index = ::bytedance::bolt::jit::loadScalarInputIndex( + codegen_.builder(), child, row); + auto* isNull = codegen_.isInputNull(nulls, index); + codegen_.builder().CreateBr(doneBlock); + auto* nullCheckEnd = codegen_.builder().GetInsertBlock(); + + codegen_.builder().SetInsertPoint(doneBlock); + auto* result = codegen_.builder().CreatePHI( + codegen_.builder().getInt1Ty(), 2, 
"row_field_is_null"); + result->addIncoming(codegen_.builder().getFalse(), noNullsEnd); + result->addIncoming(isNull, nullCheckEnd); + return result; +} + +ScalarOutputAdapterCodegen::ScalarOutputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* output) + : codegen_(codegen), output_(output) {} + +llvm::Value* ScalarOutputAdapterCodegen::vector() const { + return ::bytedance::bolt::jit::loadScalarOutputVector( + codegen_.builder(), output_); +} + +void ScalarOutputAdapterCodegen::resize(llvm::Value* size) const { + codegen_.builder().CreateCall( + codegen_.module().getFunction("jit_HashAggrResizeVector"), + {vector(), size}); +} + +void ScalarOutputAdapterCodegen::write( llvm::Value* row, HashAggrJitValueKind kind, - llvm::Value* value, - llvm::Value* isNull) const { + llvm::Value* irRow) const { + auto* value = IRRow::getValue(codegen_.builder(), irRow); + auto* isNull = IRRow::getIsNull(codegen_.builder(), irRow); if (supportsRawFlatOutput(kind)) { - auto* type = llvmType(kind); - auto* values = ::bytedance::bolt::jit::loadOutputValues(builder(), output); - auto* typedValues = builder().CreatePointerCast(values, type->getPointerTo()); - auto* valueAddr = builder().CreateInBoundsGEP( - type, typedValues, builder().CreateZExt(row, builder().getInt64Ty())); - auto* store = builder().CreateStore(value, valueAddr); + auto* type = codegen_.llvmType(kind); + auto* values = ::bytedance::bolt::jit::loadScalarOutputValues( + codegen_.builder(), output_); + auto* typedValues = codegen_.builder().CreatePointerCast( + values, type->getPointerTo()); + auto* valueAddr = codegen_.builder().CreateInBoundsGEP( + type, + typedValues, + codegen_.builder().CreateZExt(row, codegen_.builder().getInt64Ty())); + auto* store = codegen_.builder().CreateStore(value, valueAddr); store->setAlignment(llvm::Align(1)); - auto* nulls = ::bytedance::bolt::jit::loadOutputNulls(builder(), output); - ::bytedance::bolt::jit::emitOutputNullBit(builder(), nulls, row, isNull); + auto* nulls = 
::bytedance::bolt::jit::loadScalarOutputNulls( + codegen_.builder(), output_); + ::bytedance::bolt::jit::emitOutputNullBit( + codegen_.builder(), nulls, row, isNull); return; } @@ -823,48 +735,89 @@ void HashAggrJitCodegen::emitFlatValue( if (setter.empty()) { return; } - auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); - builder().CreateCall( - module_.getFunction(setter), - {vector, row, value, isNull}); + auto* isNullI8 = codegen_.builder().CreateZExt( + isNull, codegen_.builder().getInt8Ty()); + codegen_.builder().CreateCall( + codegen_.module().getFunction(setter), {vector(), row, value, isNullI8}); } -void HashAggrJitCodegen::resizeResultVector( - llvm::Value* output, - llvm::Value* size) const { - auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); - builder().CreateCall( - module_.getFunction("jit_HashAggrResizeVector"), - {vector, size}); +void ScalarOutputAdapterCodegen::writeField( + llvm::Value*, + int32_t, + HashAggrJitValueKind, + llvm::Value*) const { + BOLT_UNSUPPORTED("ScalarOutputAdapterCodegen does not support ROW field write"); } -void HashAggrJitCodegen::emitPartialAvgResult( - llvm::Value* output, +void ScalarOutputAdapterCodegen::writeNull( llvm::Value* row, - llvm::Value* sum, - llvm::Value* count, llvm::Value* isNull) const { - // The extract admission path (runHashAggrJitExtractChunks) guarantees the - // partial avg ROW output has flat sum/count children before the chunk runs, - // so rowField0/1 values are always populated and we can write them directly - // without a runtime fast/helper branch. 
- auto* sumValues = ::bytedance::bolt::jit::loadOutputRowFieldPointer( - builder(), output, 0, false); - auto* sumTypedValues = - builder().CreatePointerCast(sumValues, builder().getDoubleTy()->getPointerTo()); - auto* countValues = ::bytedance::bolt::jit::loadOutputRowFieldPointer( - builder(), output, 1, false); - auto* countTypedValues = - builder().CreatePointerCast(countValues, builder().getInt64Ty()->getPointerTo()); - auto* row64 = builder().CreateZExt(row, builder().getInt64Ty()); - auto* sumAddr = builder().CreateInBoundsGEP(builder().getDoubleTy(), sumTypedValues, row64); - auto* sumStore = builder().CreateStore(sum, sumAddr); - sumStore->setAlignment(llvm::Align(1)); - auto* countAddr = builder().CreateInBoundsGEP(builder().getInt64Ty(), countTypedValues, row64); - auto* countStore = builder().CreateStore(count, countAddr); - countStore->setAlignment(llvm::Align(1)); - auto* nulls = ::bytedance::bolt::jit::loadOutputNulls(builder(), output); - ::bytedance::bolt::jit::emitOutputNullBit(builder(), nulls, row, isNull); + auto* nulls = ::bytedance::bolt::jit::loadScalarOutputNulls( + codegen_.builder(), output_); + ::bytedance::bolt::jit::emitOutputNullBit( + codegen_.builder(), nulls, row, isNull); +} + +RowOutputAdapterCodegen::RowOutputAdapterCodegen( + HashAggrJitCodegen& codegen, + llvm::Value* output) + : codegen_(codegen), output_(output) {} + +llvm::Value* RowOutputAdapterCodegen::loadChild(int32_t field) const { + return ::bytedance::bolt::jit::loadRowOutputChild( + codegen_.builder(), output_, field); +} + +llvm::Value* RowOutputAdapterCodegen::vector() const { + return ::bytedance::bolt::jit::loadRowOutputVector(codegen_.builder(), output_); +} + +void RowOutputAdapterCodegen::resize(llvm::Value* size) const { + codegen_.builder().CreateCall( + codegen_.module().getFunction("jit_HashAggrResizeVector"), + {vector(), size}); +} + +void RowOutputAdapterCodegen::write( + llvm::Value*, + HashAggrJitValueKind, + llvm::Value*) const { + 
BOLT_UNSUPPORTED("RowOutputAdapterCodegen does not support scalar write"); +} + +void RowOutputAdapterCodegen::writeField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind, + llvm::Value* irRow) const { + BOLT_CHECK( + supportsRawFlatOutput(kind), + "Unsupported raw ROW output field kind for HashAggrJit"); + auto* child = loadChild(field); + auto* value = IRRow::getValue(codegen_.builder(), irRow); + auto* isNull = IRRow::getIsNull(codegen_.builder(), irRow); + auto* type = codegen_.llvmType(kind); + auto* values = ::bytedance::bolt::jit::loadScalarOutputValues( + codegen_.builder(), child); + auto* typedValues = + codegen_.builder().CreatePointerCast(values, type->getPointerTo()); + auto* row64 = codegen_.builder().CreateZExt(row, codegen_.builder().getInt64Ty()); + auto* valueAddr = codegen_.builder().CreateInBoundsGEP(type, typedValues, row64); + auto* store = codegen_.builder().CreateStore(value, valueAddr); + store->setAlignment(llvm::Align(1)); + auto* nulls = ::bytedance::bolt::jit::loadScalarOutputNulls( + codegen_.builder(), child); + ::bytedance::bolt::jit::emitOutputNullBit( + codegen_.builder(), nulls, row, isNull); +} + +void RowOutputAdapterCodegen::writeNull( + llvm::Value* row, + llvm::Value* isNull) const { + auto* nulls = ::bytedance::bolt::jit::loadRowOutputNulls( + codegen_.builder(), output_); + ::bytedance::bolt::jit::emitOutputNullBit( + codegen_.builder(), nulls, row, isNull); } void HashAggrJitCodegen::emitDecimalSumExtract( @@ -877,10 +830,9 @@ void HashAggrJitCodegen::emitDecimalSumExtract( : "jit_HashAggrExtractFinalDecimalSum"; auto* longDecimal = builder().getInt8( slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 
1 : 0); - auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), - {vector, + {output, row, group, builder().getInt32(slot.offset), @@ -901,10 +853,9 @@ void HashAggrJitCodegen::emitDecimalAvgExtract( slot.desc.auxPrecision > bytedance::bolt::ShortDecimalType::kMaxPrecision ? 1 : 0); - auto* vector = ::bytedance::bolt::jit::loadOutputVector(builder(), output); builder().CreateCall( module_.getFunction(fn), - {vector, + {output, row, group, builder().getInt32(slot.offset), @@ -950,6 +901,16 @@ void HashAggrJitCodegen::emitDecimalAddWithOverflow( namespace { +bool usesRowInputRuntime(const HashAggrJitSlot& slot) { + return slot.desc.mergeInput && + (slot.desc.kind == HashAggrJitKind::Avg || + (slot.desc.kind == HashAggrJitKind::Sum && slot.desc.decimal)); +} + +bool usesRowOutputRuntime(const HashAggrJitSlot& slot, bool partialOutput) { + return partialOutput && slot.desc.kind == HashAggrJitKind::Avg; +} + bool genAddDenseIR( llvm::Module& module, const std::string& fn, @@ -1027,8 +988,8 @@ bool genAddDenseIR( groups->setName("groups"); llvm::Value* numRows = &*argIt++; numRows->setName("num_rows"); - llvm::Value* decodedInputs = &*argIt++; - decodedInputs->setName("decoded_inputs"); + llvm::Value* inputRuntimes = &*argIt++; + inputRuntimes->setName("input_runtimes"); auto* entry = llvm::BasicBlock::Create(context, "entry", func); auto* loop = llvm::BasicBlock::Create(context, "loop", func); @@ -1046,18 +1007,27 @@ bool genAddDenseIR( const auto& slot = slots[i]; auto* updateBlock = llvm::BasicBlock::Create(context, "slot_update", func, end); auto* nextBlock = llvm::BasicBlock::Create(context, "slot_next", func, end); - auto* decodedAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, decodedInputs, i); - auto* decoded = builder.CreateLoad(i8PtrTy, decodedAddr); + auto* inputAddr = + builder.CreateConstInBoundsGEP1_64(i8PtrTy, inputRuntimes, i); + auto* inputRuntime = builder.CreateLoad(i8PtrTy, 
inputAddr); + std::unique_ptr input; + if (usesRowInputRuntime(slot)) { + input = std::make_unique(codegen, inputRuntime); + } else { + input = + std::make_unique(codegen, inputRuntime); + } if (checkInputNulls && !slot.desc.countStar) { - auto* nulls = codegen.loadDecodedNulls(decoded); + auto* nulls = input->loadNulls(); auto* nullCheckBlock = llvm::BasicBlock::Create(context, "slot_null_check", func, end); auto* hasNulls = builder.CreateICmpNE( - nulls, llvm::ConstantPointerNull::get(i8PtrTy)); + nulls, + llvm::ConstantPointerNull::get(builder.getInt64Ty()->getPointerTo())); builder.CreateCondBr(hasNulls, nullCheckBlock, updateBlock); builder.SetInsertPoint(nullCheckBlock); - auto* isNull = codegen.isDecodedNull(nulls, row); + auto* isNull = codegen.isInputNull(nulls, row); builder.CreateCondBr(isNull, nextBlock, updateBlock); } else { builder.CreateBr(updateBlock); @@ -1072,7 +1042,7 @@ bool genAddDenseIR( if (addFn == nullptr) { return false; } - addFn(codegen, group, decoded, row, slot, checkInputNulls, nextBlock); + addFn(codegen, group, *input, row, slot, checkInputNulls, nextBlock); builder.CreateBr(nextBlock); builder.SetInsertPoint(nextBlock); } @@ -1142,9 +1112,16 @@ bool genExtractIR( !slots[i].desc.ops->canExtract(slots[i], partialOutput)) { continue; } - auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); - auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); - codegen.resizeResultVector(vector, numGroups); + auto* outputAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); + auto* outputRuntime = builder.CreateLoad(i8PtrTy, outputAddr); + std::unique_ptr output; + if (usesRowOutputRuntime(slots[i], partialOutput)) { + output = std::make_unique(codegen, outputRuntime); + } else { + output = + std::make_unique(codegen, outputRuntime); + } + output->resize(numGroups); } builder.CreateCondBr(builder.CreateICmpSLE(numGroups, builder.getInt32(0)), end, loop); @@ -1160,13 +1137,20 @@ bool genExtractIR( 
!slot.desc.ops->canExtract(slot, partialOutput)) { continue; } - auto* vectorAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); - auto* vector = builder.CreateLoad(i8PtrTy, vectorAddr); + auto* outputAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); + auto* outputRuntime = builder.CreateLoad(i8PtrTy, outputAddr); + std::unique_ptr output; + if (usesRowOutputRuntime(slot, partialOutput)) { + output = std::make_unique(codegen, outputRuntime); + } else { + output = + std::make_unique(codegen, outputRuntime); + } if (slot.desc.ops->extract == nullptr) { return false; } slot.desc.ops->extract( - codegen, group, slot, HashAggrJitExtractTarget{vector, row, partialOutput}); + codegen, group, slot, HashAggrJitExtractTarget{*output, row, partialOutput}); } auto* next = builder.CreateAdd(row, builder.getInt32(1)); diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 264fcf9a7..c62108817 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -19,6 +19,8 @@ namespace bytedance::bolt::jit { class HashAggrJitCodegen; +class InputAdapterCodegen; +class OutputAdapterCodegen; struct HashAggrJitExtractTarget; struct HashAggrJitOps { @@ -27,7 +29,7 @@ struct HashAggrJitOps { using AddFn = void (*)( HashAggrJitCodegen&, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot&, bool checkInputNulls, @@ -48,11 +50,167 @@ struct HashAggrJitOps { }; struct HashAggrJitExtractTarget { - llvm::Value* resultVector; + const OutputAdapterCodegen& output; llvm::Value* row; bool partialOutput; }; +class IRRow { + public: + // Framework-level invariant: IRRow = {T, i1}. 'valueType' is owned by + // aggregate semantics; the null bit is always field 1. 
+ static llvm::StructType* getType( + llvm::IRBuilder<>& builder, + llvm::Type* valueType) { + return llvm::StructType::get(valueType, builder.getInt1Ty()); + } + + static llvm::Value* getValue(llvm::IRBuilder<>& builder, llvm::Value* row) { + return builder.CreateExtractValue(row, {0}); + } + + static llvm::Value* getIsNull(llvm::IRBuilder<>& builder, llvm::Value* row) { + return builder.CreateExtractValue(row, {1}); + } + + static llvm::Value* + pack(llvm::IRBuilder<>& builder, llvm::Value* value, llvm::Value* isNull) { + auto* rowType = getType(builder, value->getType()); + auto* withValue = + builder.CreateInsertValue(llvm::UndefValue::get(rowType), value, {0}); + return builder.CreateInsertValue(withValue, isNull, {1}); + } + + static llvm::Value* + withValue(llvm::IRBuilder<>& builder, llvm::Value* row, llvm::Value* value) { + return builder.CreateInsertValue(row, value, {0}); + } + + static llvm::Value* withIsNull( + llvm::IRBuilder<>& builder, + llvm::Value* row, + llvm::Value* isNull) { + return builder.CreateInsertValue(row, isNull, {1}); + } + + // Nested value access for aggregate-owned composite payloads, e.g. + // IRRow<{double, i64}> = {{double, i64}, i1}. 
+ static llvm::Value* + getValueField(llvm::IRBuilder<>& builder, llvm::Value* row, unsigned field) { + return builder.CreateExtractValue(row, {0, field}); + } +}; + +class InputAdapterCodegen { + public: + virtual ~InputAdapterCodegen() = default; + + virtual llvm::StructType* irRowType(HashAggrJitValueKind kind) const = 0; + virtual llvm::Value* read(llvm::Value* row, HashAggrJitValueKind kind) + const = 0; + virtual llvm::Value* loadNulls() const = 0; + virtual llvm::Value* isNull(llvm::Value* row) const = 0; + virtual llvm::Value* readRowField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const = 0; +}; + +class ScalarInputAdapterCodegen final : public InputAdapterCodegen { + public: + ScalarInputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* input); + + llvm::StructType* irRowType(HashAggrJitValueKind kind) const override; + llvm::Value* read(llvm::Value* row, HashAggrJitValueKind kind) const override; + llvm::Value* loadNulls() const override; + llvm::Value* isNull(llvm::Value* row) const override; + llvm::Value* readRowField(llvm::Value*, int32_t, HashAggrJitValueKind) + const override; + + private: + HashAggrJitCodegen& codegen_; + llvm::Value* input_; +}; + +class RowInputAdapterCodegen final : public InputAdapterCodegen { + public: + RowInputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* input); + + llvm::StructType* irRowType(HashAggrJitValueKind kind) const override; + llvm::Value* read(llvm::Value*, HashAggrJitValueKind) const override; + llvm::Value* loadNulls() const override; + llvm::Value* isNull(llvm::Value* row) const override; + llvm::Value* readRowField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const override; + + private: + llvm::Value* loadChild(int32_t field) const; + llvm::Value* isRowFieldNull(llvm::Value* row, int32_t field) const; + + HashAggrJitCodegen& codegen_; + llvm::Value* input_; +}; + +class OutputAdapterCodegen { + public: + virtual ~OutputAdapterCodegen() = 
default; + + virtual llvm::Value* vector() const = 0; + virtual void resize(llvm::Value* size) const = 0; + virtual void write( + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* irRow) const = 0; + virtual void writeField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind, + llvm::Value* irRow) const = 0; + virtual void writeNull(llvm::Value* row, llvm::Value* isNull) const = 0; +}; + +class ScalarOutputAdapterCodegen final : public OutputAdapterCodegen { + public: + ScalarOutputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* output); + + llvm::Value* vector() const override; + void resize(llvm::Value* size) const override; + void write( + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* irRow) const override; + void writeField(llvm::Value*, int32_t, HashAggrJitValueKind, llvm::Value*) + const override; + void writeNull(llvm::Value* row, llvm::Value* isNull) const override; + + private: + HashAggrJitCodegen& codegen_; + llvm::Value* output_; +}; + +class RowOutputAdapterCodegen final : public OutputAdapterCodegen { + public: + RowOutputAdapterCodegen(HashAggrJitCodegen& codegen, llvm::Value* output); + + llvm::Value* vector() const override; + void resize(llvm::Value* size) const override; + void write(llvm::Value*, HashAggrJitValueKind, llvm::Value*) const override; + void writeField( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind, + llvm::Value* irRow) const override; + void writeNull(llvm::Value* row, llvm::Value* isNull) const override; + + private: + llvm::Value* loadChild(int32_t field) const; + + HashAggrJitCodegen& codegen_; + llvm::Value* output_; +}; + class HashAggrJitCodegen { public: explicit HashAggrJitCodegen(llvm::Module& module); @@ -70,12 +228,7 @@ class HashAggrJitCodegen { } llvm::Type* llvmType(HashAggrJitValueKind kind) const; - llvm::Value* loadDecodedValue( - llvm::Value* decoded, - llvm::Value* row, - const HashAggrJitSlot& slot) const; - llvm::Value* 
loadDecodedNulls(llvm::Value* decoded) const; - llvm::Value* isDecodedNull(llvm::Value* nulls, llvm::Value* row) const; + llvm::Value* isInputNull(llvm::Value* nulls, llvm::Value* row) const; llvm::Value* isAccumulatorNull( llvm::Value* group, const HashAggrJitSlot& slot) const; @@ -95,35 +248,6 @@ class HashAggrJitCodegen { HashAggrJitValueKind from, HashAggrJitValueKind to) const; bool isFloatKind(HashAggrJitValueKind kind) const; - llvm::Value* loadDecodedRowField( - llvm::Value* decoded, - llvm::Value* row, - int32_t field, - HashAggrJitValueKind kind) const; - llvm::Value* isDecodedRowFieldNull( - llvm::Value* decoded, - llvm::Value* row, - int32_t field) const; - // Reads a bit-packed bool ROW field (e.g. decimal sum's isEmpty) as an i8 - // 0/1. The raw fast path bit-reads the flat bool buffer; falls back to the - // jit_GetDecodedRowFieldI8 helper when the field's raw pointer is unset. - llvm::Value* loadDecodedRowFieldBool( - llvm::Value* decoded, - llvm::Value* row, - int32_t field) const; - void emitFlatValue( - llvm::Value* vector, - llvm::Value* row, - HashAggrJitValueKind kind, - llvm::Value* value, - llvm::Value* isNull) const; - void resizeResultVector(llvm::Value* vector, llvm::Value* size) const; - void emitPartialAvgResult( - llvm::Value* vector, - llvm::Value* row, - llvm::Value* sum, - llvm::Value* count, - llvm::Value* isNull) const; // Decimal extract: calls a runtime helper that reads the JIT decimal // accumulator from 'group + slot.offset', applies overflow/precision checks // and writes the result (final flat decimal / partial row) into 'vector'. 
@@ -156,7 +280,8 @@ class HashAggrJitCodegen { llvm::IRBuilder<>* builder_{nullptr}; }; -using HashAggrJitAddDenseFunc = void (*)(char** groups, int32_t numRows, char** decodedInputs); +using HashAggrJitAddDenseFunc = + void (*)(char** groups, int32_t numRows, char** inputRuntimes); using HashAggrJitInitFunc = void (*)(char** newGroups, int32_t numNewGroups); using HashAggrJitExtractFunc = void (*)(char** groups, int32_t numGroups, char** resultVectors); @@ -181,13 +306,13 @@ class HashAggrJitChunk { void addDense( char** groups, int32_t numRows, - char** decodedInputs, + char** inputRuntimes, bool inputsMayHaveNulls) const { if (!inputsMayHaveNulls && addDenseNoNull_ != nullptr) { - addDenseNoNull_(groups, numRows, decodedInputs); + addDenseNoNull_(groups, numRows, inputRuntimes); return; } - addDense_(groups, numRows, decodedInputs); + addDense_(groups, numRows, inputRuntimes); } void extract(char** groups, int32_t numGroups, char** resultVectors) const { diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index 490407fc1..50e7e7a73 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -40,48 +40,65 @@ struct JitDecimalAvgState { int64_t overflow{0}; }; -// Runtime input descriptor consumed by JIT add_dense functions. -// GroupingSet prepares one descriptor per aggregate input for each batch by -// decoding the original vector into a flat/constant base plus a single indices -// mapping. This keeps generated IR independent of the batch's original vector -// encoding (flat/dictionary/constant) while allowing the hot loop to load -// values directly instead of calling jit_GetDecodedValue* helpers per row. -struct HashAggrJitDecodedInput { +// Runtime scalar input consumed by JIT add_dense functions. 'indices' maps the +// add_dense row to the scalar value row. 
The owner decides the indexing +// contract for 'nulls': top-level scalar inputs pass row-indexed nulls, while +// ROW child scalar inputs pass child/base-indexed nulls and the row adapter +// applies 'indices' before checking the bit. +struct HashAggrJitScalarInputRuntime { const void* values{nullptr}; - // Always points to a top-level-row -> base-row mapping. For flat inputs this - // is a consecutive mapping; for constant inputs it maps every row to the - // constant value index. const int32_t* indices{nullptr}; - // Top-level nulls. If non-null, bit 'row' indicates whether the input row is - // null. This is intentionally row-based rather than base-index-based to keep - // generated IR independent of dictionary/null wrapping details. const uint64_t* nulls{nullptr}; - // Original DecodedVector pointer. Kept as fallback for row-field helpers. - const void* decodedVector{nullptr}; - // Raw ROW child fields for intermediate avg merge inputs. The top-level - // ROW may still be dictionary/constant wrapped; 'indices' maps rows to the - // flat child row. Only the first two fields are needed by avg: sum, count. - const void* rowField0Values{nullptr}; - const uint64_t* rowField0Nulls{nullptr}; - const void* rowField1Values{nullptr}; - const uint64_t* rowField1Nulls{nullptr}; }; -// Runtime output descriptor consumed by JIT extract functions. GroupingSet -// prepares one descriptor per aggregate output after resizing the result vector. -// Primitive flat outputs write values/null bits directly from generated IR; -// complex outputs keep using vector helper fallbacks via 'vector'. -struct HashAggrJitOutput { +// Runtime ROW input. ROW itself has no value/indices wrapping in the generated +// IR; children are scalar runtimes. Current JIT merge inputs only require +// row-of-scalars, so recursive ROW children are intentionally not represented. 
+struct HashAggrJitRowInputRuntime { + const uint64_t* nulls{nullptr}; + const HashAggrJitScalarInputRuntime* const* children{nullptr}; + int32_t numChildren{0}; +}; + +// Shape-less runtime input. The generated code knows at compile time whether a +// slot reads a scalar or row input and selects the corresponding union member +// through InputAdapterCodegen. +union HashAggrJitInputRuntime { + HashAggrJitInputRuntime() : scalar{} {} + + HashAggrJitScalarInputRuntime scalar; + HashAggrJitRowInputRuntime row; +}; + +// Runtime scalar output consumed by JIT extract functions. Primitive flat +// outputs write values/null bits directly from generated IR; outputs that need +// vector semantics keep using helper fallbacks via 'vector'. +struct HashAggrJitScalarOutputRuntime { void* values{nullptr}; uint64_t* nulls{nullptr}; void* vector{nullptr}; - // Raw ROW child fields for partial avg output: field 0 = sum(double), - // field 1 = count(int64). Other outputs leave these null and use 'values' - // or helper fallback via 'vector'. - void* rowField0Values{nullptr}; - uint64_t* rowField0Nulls{nullptr}; - void* rowField1Values{nullptr}; - uint64_t* rowField1Nulls{nullptr}; +}; + +// Runtime ROW output. Current JIT partial outputs only require row-of-scalars, +// so recursive ROW children are intentionally not represented. 'vector' keeps +// the top-level vector available for helper based complex writes (e.g. decimal +// partial extract), while child scalar runtimes expose raw field buffers for +// direct generated-IR stores. +struct HashAggrJitRowOutputRuntime { + uint64_t* nulls{nullptr}; + HashAggrJitScalarOutputRuntime* const* children{nullptr}; + int32_t numChildren{0}; + void* vector{nullptr}; +}; + +// Shape-less runtime output. The generated extract code knows at compile time +// whether a slot writes a scalar or row output and selects the corresponding +// union member through OutputAdapterCodegen. 
+union HashAggrJitOutputRuntime { + HashAggrJitOutputRuntime() : scalar{} {} + + HashAggrJitScalarOutputRuntime scalar; + HashAggrJitRowOutputRuntime row; }; struct HashAggrJitPlanContext { diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index abdc8699d..078df4e69 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -33,12 +33,13 @@ void compileAvgInitGroup( void compileAvgAddRawInput( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* inputRow = input.read(row, slot.desc.inputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); auto* value = codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); codegen.clearAccumulatorNull(group, slot); @@ -63,16 +64,16 @@ void compileAvgAddRawInput( void compileAvgAddIntermediateResults( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { codegen.clearAccumulatorNull(group, slot); - auto* sum = - codegen.loadDecodedRowField(decoded, row, 0, HashAggrJitValueKind::Double); - auto* count = - codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int64); + auto* sumRow = input.readRowField(row, 0, HashAggrJitValueKind::Double); + auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); + auto* sum = IRRow::getValue(codegen.builder(), sumRow); + auto* count = IRRow::getValue(codegen.builder(), countRow); auto* oldSum = codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); codegen.storeValue( @@ -109,21 +110,27 @@ void compileAvgExtract( // Intermediate output is row(sum:double, count:bigint). 
All-null group // yields (0, 0) with a non-null top-level row (isNull = 0), matching the // non-JIT extractAccumulators path. - codegen.emitPartialAvgResult( - target.resultVector, target.row, sum, count, builder.getInt8(0)); + target.output.writeField( + target.row, + 0, + HashAggrJitValueKind::Double, + IRRow::pack(builder, sum, builder.getFalse())); + target.output.writeField( + target.row, + 1, + HashAggrJitValueKind::Int64, + IRRow::pack(builder, count, builder.getFalse())); + target.output.writeNull(target.row, builder.getFalse()); return; } // Final output is double avg. count == 0 means all inputs were null -> null. - auto* isNull = builder.CreateZExt( - builder.CreateICmpEQ(count, builder.getInt64(0)), builder.getInt8Ty()); + auto* isNull = builder.CreateICmpEQ(count, builder.getInt64(0)); auto* countAsDouble = builder.CreateSIToFP(count, builder.getDoubleTy()); auto* avg = builder.CreateFDiv(sum, countAsDouble); - codegen.emitFlatValue( - target.resultVector, + target.output.write( target.row, HashAggrJitValueKind::Double, - avg, - isNull); + IRRow::pack(builder, avg, isNull)); } } // namespace diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp index 3e586e25d..cdd8067d5 100644 --- a/bolt/jit/aggregation/ops/CountOps.cpp +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -39,7 +39,7 @@ void addInc( void compileCountAddRawInput( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* /*decoded*/, + const InputAdapterCodegen& /*input*/, llvm::Value* /*row*/, const HashAggrJitSlot& slot, bool, @@ -50,17 +50,21 @@ void compileCountAddRawInput( void compileCountAddIntermediateResults( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { - llvm::Value* inc = slot.desc.countStar - ? 
codegen.builder().getInt64(1) - : codegen.castValue( - codegen.loadDecodedValue(decoded, row, slot), - slot.desc.inputKind, - HashAggrJitValueKind::Int64); + llvm::Value* inc = nullptr; + if (slot.desc.countStar) { + inc = codegen.builder().getInt64(1); + } else { + auto* inputRow = input.read(row, slot.desc.inputKind); + inc = codegen.castValue( + IRRow::getValue(codegen.builder(), inputRow), + slot.desc.inputKind, + HashAggrJitValueKind::Int64); + } addInc(codegen, group, slot, inc); } @@ -76,12 +80,10 @@ void compileCountExtract( const HashAggrJitExtractTarget& target) { auto* value = codegen.loadValue(group, codegen.builder().getInt64Ty(), slot.offset); - codegen.emitFlatValue( - target.resultVector, + target.output.write( target.row, HashAggrJitValueKind::Int64, - value, - codegen.builder().getInt8(0)); + IRRow::pack(codegen.builder(), value, codegen.builder().getFalse())); } } // namespace diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index f78d32ea8..ffa90fc36 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -42,13 +42,14 @@ void compileDecimalAvgInitGroup( void compileDecimalAvgAddRawInput( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { auto& b = codegen.builder(); - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* inputRow = input.read(row, slot.desc.inputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); auto* value = codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); @@ -67,7 +68,7 @@ void compileDecimalAvgAddRawInput( void compileDecimalAvgAddIntermediateResults( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const 
HashAggrJitSlot& slot, bool, @@ -86,10 +87,11 @@ void compileDecimalAvgAddIntermediateResults( continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "avg_decimal_merge", function, continueBlock); - auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* countIsNull = codegen.isDecodedRowFieldNull(decoded, row, 1); - auto* count = - codegen.loadDecodedRowField(decoded, row, 1, HashAggrJitValueKind::Int64); + auto* sumRow = input.readRowField(row, 0, slot.desc.inputKind); + auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); + auto* sumIsNull = IRRow::getIsNull(b, sumRow); + auto* countIsNull = IRRow::getIsNull(b, countRow); + auto* count = IRRow::getValue(b, countRow); auto* countPositive = b.CreateICmpSGT(count, b.getInt64(0)); auto* isOverflow = b.CreateAnd( sumIsNull, b.CreateAnd(b.CreateNot(countIsNull), countPositive)); @@ -100,7 +102,7 @@ void compileDecimalAvgAddIntermediateResults( b.CreateBr(continueBlock); b.SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); + auto* sum = IRRow::getValue(b, sumRow); auto* value = codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); @@ -132,7 +134,7 @@ void compileDecimalAvgExtract( const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { codegen.emitDecimalAvgExtract( - target.resultVector, target.row, group, slot, target.partialOutput); + target.output.vector(), target.row, group, slot, target.partialOutput); } } // namespace diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index cc3e8d4e0..9e6720517 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -42,13 +42,14 @@ void compileDecimalSumInitGroup( void compileDecimalSumAddRawInput( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* 
decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { auto& b = codegen.builder(); - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* inputRow = input.read(row, slot.desc.inputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); auto* value = codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); @@ -61,7 +62,7 @@ void compileDecimalSumAddRawInput( void compileDecimalSumAddIntermediateResults( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, @@ -80,9 +81,10 @@ void compileDecimalSumAddIntermediateResults( continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "sum_decimal_merge", function, continueBlock); - auto* sumIsNull = codegen.isDecodedRowFieldNull(decoded, row, 0); - auto* incomingIsEmpty = - codegen.loadDecodedRowFieldBool(decoded, row, 1); + auto* sumRow = input.readRowField(row, 0, slot.desc.inputKind); + auto* isEmptyRow = input.readRowField(row, 1, HashAggrJitValueKind::Bool); + auto* sumIsNull = IRRow::getIsNull(b, sumRow); + auto* incomingIsEmpty = IRRow::getValue(b, isEmptyRow); auto* isNotEmpty = b.CreateICmpEQ(incomingIsEmpty, b.getInt8(0)); auto* isOverflow = b.CreateAnd(sumIsNull, isNotEmpty); b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); @@ -92,7 +94,7 @@ void compileDecimalSumAddIntermediateResults( b.CreateBr(continueBlock); b.SetInsertPoint(mergeBlock); - auto* sum = codegen.loadDecodedRowField(decoded, row, 0, slot.desc.inputKind); + auto* sum = IRRow::getValue(b, sumRow); auto* value = codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); @@ -124,7 +126,7 @@ void compileDecimalSumExtract( const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& 
target) { codegen.emitDecimalSumExtract( - target.resultVector, target.row, group, slot, target.partialOutput); + target.output.vector(), target.row, group, slot, target.partialOutput); } } // namespace diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp index 2a39d9c0f..ed101d108 100644 --- a/bolt/jit/aggregation/ops/MinMaxOps.cpp +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -29,13 +29,14 @@ void compileMinMaxInitGroup( void compileMinMaxUpdate( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { + auto* inputRow = input.read(row, slot.desc.inputKind); auto* value = codegen.castValue( - codegen.loadDecodedValue(decoded, row, slot), + IRRow::getValue(codegen.builder(), inputRow), slot.desc.inputKind, slot.desc.accumulatorKind); auto* type = codegen.llvmType(slot.desc.accumulatorKind); @@ -85,10 +86,11 @@ void compileMinMaxExtract( const HashAggrJitExtractTarget& target) { auto* value = codegen.loadValue( group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); - auto* isNull = codegen.builder().CreateZExt( - codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); - codegen.emitFlatValue( - target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); + auto* isNull = codegen.isAccumulatorNull(group, slot); + target.output.write( + target.row, + slot.desc.accumulatorKind, + IRRow::pack(codegen.builder(), value, isNull)); } } // namespace diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp index 6324034db..e2815db4d 100644 --- a/bolt/jit/aggregation/ops/SumOps.cpp +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -31,12 +31,13 @@ void compileSumInitGroup( void compileSumAccumulate( HashAggrJitCodegen& codegen, llvm::Value* group, - llvm::Value* decoded, + const InputAdapterCodegen& input, llvm::Value* row, const 
HashAggrJitSlot& slot, bool, llvm::BasicBlock*) { - auto* rawValue = codegen.loadDecodedValue(decoded, row, slot); + auto* inputRow = input.read(row, slot.desc.inputKind); + auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); auto* value = codegen.castValue( rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); auto* accType = codegen.llvmType(slot.desc.accumulatorKind); @@ -61,10 +62,11 @@ void compileSumExtract( const HashAggrJitExtractTarget& target) { auto* value = codegen.loadValue( group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); - auto* isNull = codegen.builder().CreateZExt( - codegen.isAccumulatorNull(group, slot), codegen.builder().getInt8Ty()); - codegen.emitFlatValue( - target.resultVector, target.row, slot.desc.accumulatorKind, value, isNull); + auto* isNull = codegen.isAccumulatorNull(group, slot); + target.output.write( + target.row, + slot.desc.accumulatorKind, + IRRow::pack(codegen.builder(), value, isNull)); } } // namespace From 921da8bfb3eb413b9cde61b50da82fcf0cccf7f3 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 11:59:38 +0800 Subject: [PATCH 50/98] fix style --- bolt/jit/aggregation/HashAggrJit.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 5b2e45f0c..96089b3ca 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1173,11 +1173,13 @@ HashAggrJitChunk::HashAggrJitChunk( out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" << slots_.size(); for (const auto& slot : slots_) { - out << "_" << (slot.desc.ops != nullptr ? slot.desc.ops->id : "unknown") << "_" - << static_cast(slot.desc.kind) << hashAggrJitValueKindName(slot.desc.inputKind) - << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" << slot.offset - << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) - << (slot.desc.countStar ? 
"s" : "x") << (slot.desc.mergeInput ? "g" : "r") + out << "_" << (slot.desc.ops != nullptr ? slot.desc.ops->id : "unknown") + << "_" << static_cast(slot.desc.kind) + << hashAggrJitValueKindName(slot.desc.inputKind) + << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" + << slot.offset << "n" << slot.nullByte << "m" + << static_cast(slot.nullMask) << (slot.desc.countStar ? "s" : "x") + << (slot.desc.mergeInput ? "g" : "r") << (slot.desc.decimal ? "d" : "n"); } functionName_ = out.str(); From c0619e74d1c40ca02d0a3b9c358c54ee069a17f1 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 15:22:17 +0800 Subject: [PATCH 51/98] minor refactors --- bolt/exec/GroupingSet.cpp | 27 ++----------------------- bolt/jit/aggregation/HashAggrJitTypes.h | 7 ------- bolt/jit/aggregation/ops/AvgOps.cpp | 13 ++++++++++-- 3 files changed, 13 insertions(+), 34 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 04ab0631b..8c71ff5cb 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -123,21 +123,6 @@ bool fillHashAggrJitRowInputRuntime( DecodedVector& decoded, const SelectivityVector& rows, const jit::HashAggrJitSlot& slot) { - // The raw row-field fast path applies to merge inputs whose intermediate - // representation is ROW(field0, field1): avg's ROW(sum, count) and decimal - // sum's ROW(sum, isEmpty). For both, field0 is the running sum read on the - // hot merge path; populating its raw pointer lets generated IR load it - // directly instead of calling the per-row jit_GetDecodedRowField* helper - // (which rebuilds a field DecodedVector on every call). 
- if (!slot.desc.mergeInput) { - return false; - } - const bool isAvg = slot.desc.kind == jit::HashAggrJitKind::Avg; - const bool isDecimalSum = - slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal; - if (!isAvg && !isDecimalSum) { - return false; - } const auto* base = decoded.base(); if (base == nullptr || base->encoding() != VectorEncoding::Simple::ROW) { return false; @@ -1288,19 +1273,11 @@ void GroupingSet::runHashAggrJitExtractChunks( skipReason = "distinct/mask/sortingKeys not supported"; break; } - auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); - const auto expectedEncoding = - (isPartial_ && slot.desc.kind == jit::HashAggrJitKind::Avg) - ? VectorEncoding::Simple::ROW - : VectorEncoding::Simple::FLAT; - if (aggregateVector->encoding() != expectedEncoding) { - canRunChunk = false; - skipReason = "unexpected result vector encoding"; - break; - } // Prepare stable raw output pointers after resizing. The JIT extract // function still receives char** for ABI compatibility, but each element // now points to HashAggrJitOutputRuntime rather than BaseVector directly. + auto& aggregateVector = + result->childAt(slot.aggregateIndex + aggregateOutputOffset); aggregateVector->resize(groups.size()); if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { hashAggrJitOutputRuntimes_[slotIndex].scalar = diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index 50e7e7a73..d9e04ac2f 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -18,13 +18,6 @@ namespace bytedance::bolt::jit { -// JIT-internal accumulator layout for avg. Shared between avg ops codegen and -// any runtime/helper logic that needs to reason about the in-row state layout. -struct JitAvgState { - double sum{0}; - int64_t count{0}; -}; - // JIT-internal accumulator layouts for decimal sum/avg. 
Shared between the JIT // codegen runtime helpers and the extract runtime helpers (which live in a // different translation unit and need DecimalUtil). diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index 078df4e69..bca74a945 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -7,11 +7,20 @@ #include "bolt/jit/aggregation/HashAggrJit.h" +#include + namespace bytedance::bolt::jit { namespace { -constexpr int32_t kAvgCountOffset = offsetof(JitAvgState, count); +struct AvgAccumulatorLayout { + double sum; + int64_t count; +}; + +static_assert(std::is_standard_layout_v); + +constexpr int32_t kAvgCountOffset = offsetof(AvgAccumulatorLayout, count); void compileAvgInitGroup( HashAggrJitCodegen& codegen, @@ -93,7 +102,7 @@ void compileAvgAddIntermediateResults( } bool canCompileAvgExtract(const HashAggrJitSlot& slot, bool) { - // Only double avg (JitAvgState) is supported. + // Only double avg accumulator layout is supported. return slot.desc.accumulatorKind == HashAggrJitValueKind::Double; } From b73168dc222fdb9b686b2816b721d5910c5b0b69 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 15:38:10 +0800 Subject: [PATCH 52/98] refactor hash aggr jit adapters and remove dead builtins: Replace decoded input/output plumbing with runtime adapters, localize aggregate-specific state layouts, and remove unused decoded helper and flat setter builtins from hash aggr JIT. 
--- bolt/exec/GroupingSet.cpp | 28 ++++++++- bolt/jit/aggregation/HashAggrJit.cpp | 61 ------------------- .../aggregation/runtime/HashAggrRuntime.cpp | 60 ------------------ 3 files changed, 26 insertions(+), 123 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 8c71ff5cb..0763fd391 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -123,6 +123,21 @@ bool fillHashAggrJitRowInputRuntime( DecodedVector& decoded, const SelectivityVector& rows, const jit::HashAggrJitSlot& slot) { + // The raw row-field fast path applies to merge inputs whose intermediate + // representation is ROW(field0, field1): avg's ROW(sum, count) and decimal + // sum's ROW(sum, isEmpty). For both, field0 is the running sum read on the + // hot merge path; populating its raw pointer lets generated IR load it + // directly instead of calling the per-row jit_GetDecodedRowField* helper + // (which rebuilds a field DecodedVector on every call). + if (!slot.desc.mergeInput) { + return false; + } + const bool isAvg = slot.desc.kind == jit::HashAggrJitKind::Avg; + const bool isDecimalSum = + slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal; + if (!isAvg && !isDecimalSum) { + return false; + } const auto* base = decoded.base(); if (base == nullptr || base->encoding() != VectorEncoding::Simple::ROW) { return false; @@ -1273,11 +1288,20 @@ void GroupingSet::runHashAggrJitExtractChunks( skipReason = "distinct/mask/sortingKeys not supported"; break; } + auto& aggregateVector = + result->childAt(slot.aggregateIndex + aggregateOutputOffset); + const auto expectedEncoding = + (isPartial_ && slot.desc.kind == jit::HashAggrJitKind::Avg) + ? VectorEncoding::Simple::ROW + : VectorEncoding::Simple::FLAT; + if (aggregateVector->encoding() != expectedEncoding) { + canRunChunk = false; + skipReason = "unexpected result vector encoding"; + break; + } // Prepare stable raw output pointers after resizing. 
The JIT extract // function still receives char** for ABI compatibility, but each element // now points to HashAggrJitOutputRuntime rather than BaseVector directly. - auto& aggregateVector = - result->childAt(slot.aggregateIndex + aggregateOutputOffset); aggregateVector->resize(groups.size()); if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { hashAggrJitOutputRuntimes_[slotIndex].scalar = diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 96089b3ca..40d32d79b 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -129,34 +129,7 @@ void ensureBuiltinDeclarations(llvm::Module& module) { auto* voidTy = llvm::Type::getVoidTy(context); auto* i8PtrTy = llvm::PointerType::get(context, 0); - declareFunction(module, "jit_GetDecodedValueBool", i8Ty, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_GetDecodedValueI8", i8Ty, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_GetDecodedValueI16", i16Ty, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_GetDecodedValueI32", i32Ty, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_GetDecodedValueI64", i64Ty, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_GetDecodedValueI128", i128Ty, {i8PtrTy, i32Ty}); - declareFunction( - module, "jit_GetDecodedValueFloat", floatTy, {i8PtrTy, i32Ty}); - declareFunction( - module, "jit_GetDecodedValueDouble", doubleTy, {i8PtrTy, i32Ty}); - declareFunction( - module, "jit_GetDecodedRowFieldDouble", doubleTy, {i8PtrTy, i32Ty, i32Ty}); - declareFunction( - module, "jit_GetDecodedRowFieldI8", i8Ty, {i8PtrTy, i32Ty, i32Ty}); - declareFunction( - module, "jit_GetDecodedRowFieldI64", i64Ty, {i8PtrTy, i32Ty, i32Ty}); - declareFunction( - module, "jit_GetDecodedRowFieldI128", i128Ty, {i8PtrTy, i32Ty, i32Ty}); - declareFunction( - module, "jit_GetDecodedRowFieldIsNull", i8Ty, {i8PtrTy, i32Ty, i32Ty}); - declareFunction(module, "jit_GetDecodedIsNull", i8Ty, {i8PtrTy, i32Ty}); declareFunction(module, 
"jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); - declareFunction(module, "jit_HashAggrSetFlatI8", voidTy, {i8PtrTy, i32Ty, i8Ty, i8Ty}); - declareFunction(module, "jit_HashAggrSetFlatI16", voidTy, {i8PtrTy, i32Ty, i16Ty, i8Ty}); - declareFunction(module, "jit_HashAggrSetFlatI32", voidTy, {i8PtrTy, i32Ty, i32Ty, i8Ty}); - declareFunction(module, "jit_HashAggrSetFlatI64", voidTy, {i8PtrTy, i32Ty, i64Ty, i8Ty}); - declareFunction(module, "jit_HashAggrSetFlatFloat", voidTy, {i8PtrTy, i32Ty, floatTy, i8Ty}); - declareFunction(module, "jit_HashAggrSetFlatDouble", voidTy, {i8PtrTy, i32Ty, doubleTy, i8Ty}); // Decimal extract helpers. // Sum: (vector, row, group, offset, precision, scale, longDecimal). declareFunction( @@ -202,8 +175,6 @@ llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { return builder.getInt64Ty(); } -std::string setFlatValueFunction(HashAggrJitValueKind kind); - bool isFloatKind(HashAggrJitValueKind kind) { return kind == HashAggrJitValueKind::Float || kind == HashAggrJitValueKind::Double; @@ -730,15 +701,6 @@ void ScalarOutputAdapterCodegen::write( codegen_.builder(), nulls, row, isNull); return; } - - const auto setter = setFlatValueFunction(kind); - if (setter.empty()) { - return; - } - auto* isNullI8 = codegen_.builder().CreateZExt( - isNull, codegen_.builder().getInt8Ty()); - codegen_.builder().CreateCall( - codegen_.module().getFunction(setter), {vector(), row, value, isNullI8}); } void ScalarOutputAdapterCodegen::writeField( @@ -1057,29 +1019,6 @@ bool genAddDenseIR( return !llvm::verifyFunction(*func, &llvm::errs()); } -std::string setFlatValueFunction(HashAggrJitValueKind kind) { - switch (kind) { - case HashAggrJitValueKind::Int8: - return "jit_HashAggrSetFlatI8"; - case HashAggrJitValueKind::Int16: - return "jit_HashAggrSetFlatI16"; - case HashAggrJitValueKind::Int32: - return "jit_HashAggrSetFlatI32"; - case HashAggrJitValueKind::Int64: - return "jit_HashAggrSetFlatI64"; - case HashAggrJitValueKind::Float: 
- return "jit_HashAggrSetFlatFloat"; - case HashAggrJitValueKind::Double: - return "jit_HashAggrSetFlatDouble"; - // Bool output vectors are FlatVector, which cannot reuse the int8 - // setter. JIT extract is not yet supported for Bool. - case HashAggrJitValueKind::Bool: - case HashAggrJitValueKind::Int128: - return ""; - } - return ""; -} - bool genExtractIR( llvm::Module& module, const std::string& fn, diff --git a/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp index c1b5d7e3f..c25b04045 100644 --- a/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp +++ b/bolt/jit/aggregation/runtime/HashAggrRuntime.cpp @@ -20,64 +20,4 @@ __attribute__((__visibility__("default"))) void jit_HashAggrResizeVector( reinterpret_cast(vector)->resize(size); } -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI8( - char* vector, - int32_t row, - int8_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI16( - char* vector, - int32_t row, - int16_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI32( - char* vector, - int32_t row, - int32_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatI64( - char* vector, - int32_t row, - int64_t value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? 
flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatFloat( - char* vector, - int32_t row, - float value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - -__attribute__((__visibility__("default"))) void jit_HashAggrSetFlatDouble( - char* vector, - int32_t row, - double value, - int8_t isNull) { - auto* flat = reinterpret_cast(vector) - ->as>(); - isNull ? flat->setNull(row, true) : flat->set(row, value); -} - } // extern "C" From f5bb503576a3a47726a68328b3d3f7e6e28272f6 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 15:40:53 +0800 Subject: [PATCH 53/98] remove uselesss declarations --- bolt/jit/aggregation/HashAggrJit.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 40d32d79b..870876226 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -120,12 +120,7 @@ llvm::FunctionCallee declareFunction( void ensureBuiltinDeclarations(llvm::Module& module) { auto& context = module.getContext(); auto* i8Ty = llvm::Type::getInt8Ty(context); - auto* i16Ty = llvm::Type::getInt16Ty(context); auto* i32Ty = llvm::Type::getInt32Ty(context); - auto* i64Ty = llvm::Type::getInt64Ty(context); - auto* i128Ty = llvm::Type::getInt128Ty(context); - auto* floatTy = llvm::Type::getFloatTy(context); - auto* doubleTy = llvm::Type::getDoubleTy(context); auto* voidTy = llvm::Type::getVoidTy(context); auto* i8PtrTy = llvm::PointerType::get(context, 0); From 2eab5b8be84f76010500607cc7dab75ea70c531d Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 16:09:50 +0800 Subject: [PATCH 54/98] push down input/output shape from framework to individual aggregate functions --- bolt/exec/GroupingSet.cpp | 16 ++++++------ .../prestosql/aggregates/CountAggregate.cpp | 2 ++ 
.../prestosql/aggregates/MinMaxAggregates.cpp | 2 ++ .../sparksql/aggregates/AverageAggregate.cpp | 13 ++++++++++ .../sparksql/aggregates/DecimalSumAggregate.h | 3 +++ .../sparksql/aggregates/SumAggregate.cpp | 2 ++ bolt/jit/aggregation/HashAggrJit.cpp | 26 ++++++++++++++----- bolt/jit/aggregation/HashAggrJitTypes.h | 7 +++++ 8 files changed, 56 insertions(+), 15 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 0763fd391..8b7cf6e59 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -129,7 +129,7 @@ bool fillHashAggrJitRowInputRuntime( // hot merge path; populating its raw pointer lets generated IR load it // directly instead of calling the per-row jit_GetDecodedRowField* helper // (which rebuilds a field DecodedVector on every call). - if (!slot.desc.mergeInput) { + if (slot.desc.inputShape != jit::HashAggrJitRuntimeShape::Row) { return false; } const bool isAvg = slot.desc.kind == jit::HashAggrJitKind::Avg; @@ -1198,9 +1198,8 @@ void GroupingSet::runHashAggrJitChunks( } hashAggrJitInputVectors_[slotIndex] = arg; hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); - const bool usesRowInputRuntime = slot.desc.mergeInput && - (slot.desc.kind == jit::HashAggrJitKind::Avg || - (slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal)); + const bool usesRowInputRuntime = + slot.desc.inputShape == jit::HashAggrJitRuntimeShape::Row; if (usesRowInputRuntime) { if (!fillHashAggrJitRowInputRuntime( hashAggrJitInputRuntimes_[slotIndex], @@ -1291,7 +1290,7 @@ void GroupingSet::runHashAggrJitExtractChunks( auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); const auto expectedEncoding = - (isPartial_ && slot.desc.kind == jit::HashAggrJitKind::Avg) + slot.desc.outputShape == jit::HashAggrJitRuntimeShape::Row ? 
VectorEncoding::Simple::ROW : VectorEncoding::Simple::FLAT; if (aggregateVector->encoding() != expectedEncoding) { @@ -1310,8 +1309,9 @@ void GroupingSet::runHashAggrJitExtractChunks( aggregateVector.get(), slot.desc.accumulatorKind), .nulls = aggregateVector->mutableRawNulls(), .vector = aggregateVector.get()}; - } else if (aggregateVector->encoding() == VectorEncoding::Simple::ROW && - slot.desc.kind == jit::HashAggrJitKind::Avg) { + } else if ( + aggregateVector->encoding() == VectorEncoding::Simple::ROW && + slot.desc.outputShape == jit::HashAggrJitRuntimeShape::Row) { if (!fillHashAggrJitRowOutputRuntime( hashAggrJitOutputRuntimes_[slotIndex], hashAggrJitRowOutputChildren_[slotIndex], @@ -1319,7 +1319,7 @@ void GroupingSet::runHashAggrJitExtractChunks( aggregateVector.get(), slot)) { canRunChunk = false; - skipReason = "partial avg row fields are not flat"; + skipReason = "ROW output runtime requires flat scalar row children"; break; } } diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 78fcc657a..34b533e8d 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -87,6 +87,8 @@ class CountAggregate : public SimpleNumericAggregate { .countStar = context.isCountStar(), .mergeInput = !context.isRawInput, .decimal = false, + .inputShape = jit::HashAggrJitRuntimeShape::Scalar, + .outputShape = jit::HashAggrJitRuntimeShape::Scalar, .precision = 0, .scale = 0, .auxPrecision = 0, diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 868357cd7..364960d4d 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -78,6 +78,8 @@ class MinMaxAggregate : public SimpleNumericAggregate { .countStar = false, .mergeInput = !context.isRawInput, .decimal = false, + 
.inputShape = jit::HashAggrJitRuntimeShape::Scalar, + .outputShape = jit::HashAggrJitRuntimeShape::Scalar, .precision = 0, .scale = 0, .auxPrecision = 0, diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 70e46a962..344e48315 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -75,6 +75,10 @@ class AverageAggregate .countStar = false, .mergeInput = true, .decimal = false, + .inputShape = jit::HashAggrJitRuntimeShape::Row, + .outputShape = context.isPartialOutput + ? jit::HashAggrJitRuntimeShape::Row + : jit::HashAggrJitRuntimeShape::Scalar, .precision = 0, .scale = 0, .auxPrecision = 0, @@ -93,6 +97,10 @@ class AverageAggregate .countStar = false, .mergeInput = false, .decimal = false, + .inputShape = jit::HashAggrJitRuntimeShape::Scalar, + .outputShape = context.isPartialOutput + ? jit::HashAggrJitRuntimeShape::Row + : jit::HashAggrJitRuntimeShape::Scalar, .precision = 0, .scale = 0, .auxPrecision = 0, @@ -189,6 +197,11 @@ class DecimalAverageAggregate : public DecimalAggregate { .countStar = false, .mergeInput = !context.isRawInput, .decimal = true, + .inputShape = context.isRawInput ? jit::HashAggrJitRuntimeShape::Scalar + : jit::HashAggrJitRuntimeShape::Row, + .outputShape = context.isPartialOutput + ? 
jit::HashAggrJitRuntimeShape::Row + : jit::HashAggrJitRuntimeShape::Scalar, .precision = sumPrecision, .scale = sumScale, .auxPrecision = resultPrecision, diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 4a2ed825c..fffe34e9a 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -95,6 +95,9 @@ class DecimalSumAggregate : public exec::Aggregate { .countStar = false, .mergeInput = !context.isRawInput, .decimal = true, + .inputShape = context.isRawInput ? jit::HashAggrJitRuntimeShape::Scalar + : jit::HashAggrJitRuntimeShape::Row, + .outputShape = jit::HashAggrJitRuntimeShape::Scalar, .precision = resultPrecision, .scale = resultScale, .auxPrecision = 0, diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index fe76bf2f4..35b3a3950 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -80,6 +80,8 @@ class SumAggregate : public SumAggregateBase { .countStar = false, .mergeInput = !context.isRawInput, .decimal = false, + .inputShape = jit::HashAggrJitRuntimeShape::Scalar, + .outputShape = jit::HashAggrJitRuntimeShape::Scalar, .precision = 0, .scale = 0, .auxPrecision = 0, diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 870876226..08b882fcd 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -858,14 +858,22 @@ void HashAggrJitCodegen::emitDecimalAddWithOverflow( namespace { +char hashAggrJitRuntimeShapeName(HashAggrJitRuntimeShape shape) { + switch (shape) { + case HashAggrJitRuntimeShape::Scalar: + return 's'; + case HashAggrJitRuntimeShape::Row: + return 'r'; + } + return 'u'; +} + bool usesRowInputRuntime(const HashAggrJitSlot& slot) { - return slot.desc.mergeInput && - 
(slot.desc.kind == HashAggrJitKind::Avg || - (slot.desc.kind == HashAggrJitKind::Sum && slot.desc.decimal)); + return slot.desc.inputShape == HashAggrJitRuntimeShape::Row; } bool usesRowOutputRuntime(const HashAggrJitSlot& slot, bool partialOutput) { - return partialOutput && slot.desc.kind == HashAggrJitKind::Avg; + return partialOutput && slot.desc.outputShape == HashAggrJitRuntimeShape::Row; } bool genAddDenseIR( @@ -1114,7 +1122,9 @@ HashAggrJitChunk::HashAggrJitChunk( << slot.offset << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) << (slot.desc.countStar ? "s" : "x") << (slot.desc.mergeInput ? "g" : "r") - << (slot.desc.decimal ? "d" : "n"); + << (slot.desc.decimal ? "d" : "n") << "i" + << hashAggrJitRuntimeShapeName(slot.desc.inputShape) << "o" + << hashAggrJitRuntimeShapeName(slot.desc.outputShape); } functionName_ = out.str(); initFunctionName_ = functionName_ + "_init"; @@ -1185,13 +1195,15 @@ bool isHashAggrJitSupportedType(TypeKind kind) { std::string HashAggrJitDescriptor::signature() const { return fmt::format( - "{}_{}_{}_{}_{}_{}", + "{}_{}_{}_{}_{}_{}_{}_{}", ops != nullptr ? ops->id : "unknown", static_cast(kind), hashAggrJitValueKindName(inputKind), hashAggrJitValueKindName(accumulatorKind), mergeInput, - decimal); + decimal, + hashAggrJitRuntimeShapeName(inputShape), + hashAggrJitRuntimeShapeName(outputShape)); } bool HashAggrJitChunk::canExtract() const { diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index d9e04ac2f..b053bfcdb 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -124,6 +124,11 @@ enum class HashAggrJitValueKind : uint8_t { Double, }; +enum class HashAggrJitRuntimeShape : uint8_t { + Scalar, + Row, +}; + // Forward declaration: the codegen function-pointer table is defined in // HashAggrJit.h (it references llvm:: types). 
Descriptors only hold a pointer // to it, so a forward declaration is enough here and keeps this header @@ -137,6 +142,8 @@ struct HashAggrJitDescriptor { bool countStar{false}; bool mergeInput{false}; bool decimal{false}; + HashAggrJitRuntimeShape inputShape{HashAggrJitRuntimeShape::Scalar}; + HashAggrJitRuntimeShape outputShape{HashAggrJitRuntimeShape::Scalar}; // Result decimal precision/scale, used by decimal extract overflow checks. // Only meaningful when decimal == true. int32_t precision{0}; From 6ef9318db3178e030c9bc4ca2251567d6a767871 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 16:36:13 +0800 Subject: [PATCH 55/98] decouple hash aggr jit row runtime binding from aggregate kinds --- bolt/exec/GroupingSet.cpp | 153 +++++++++++++++++++------------------- 1 file changed, 75 insertions(+), 78 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 8b7cf6e59..bb21069a2 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -69,10 +69,24 @@ std::string hashAggrJitTypeName(const TypePtr& type) { return type == nullptr ? 
"null" : type->toString(); } -void* hashAggrJitRawOutputValues( +std::optional hashAggrJitOutputValueKind( + const BaseVector* vector) { + const auto& type = vector->type(); + if (type->isShortDecimal()) { + return jit::HashAggrJitValueKind::Int64; + } + if (type->isLongDecimal()) { + return jit::HashAggrJitValueKind::Int128; + } + return jit::hashAggrJitValueKind(type->kind()); +} + +void* hashAggrJitRawOutputData( BaseVector* vector, jit::HashAggrJitValueKind kind) { switch (kind) { + case jit::HashAggrJitValueKind::Bool: + return const_cast(vector->valuesAsVoid()); case jit::HashAggrJitValueKind::Int8: return vector->asUnchecked>()->mutableRawValues(); case jit::HashAggrJitValueKind::Int16: @@ -85,17 +99,30 @@ void* hashAggrJitRawOutputValues( return vector->asUnchecked>()->mutableRawValues(); case jit::HashAggrJitValueKind::Double: return vector->asUnchecked>()->mutableRawValues(); - case jit::HashAggrJitValueKind::Bool: case jit::HashAggrJitValueKind::Int128: - return nullptr; + return vector->asUnchecked>()->mutableRawValues(); } return nullptr; } -const void* hashAggrJitRawInputValues( +std::optional hashAggrJitInputValueKind( + const BaseVector* vector) { + const auto& type = vector->type(); + if (type->isShortDecimal()) { + return jit::HashAggrJitValueKind::Int64; + } + if (type->isLongDecimal()) { + return jit::HashAggrJitValueKind::Int128; + } + return jit::hashAggrJitValueKind(type->kind()); +} + +const void* hashAggrJitRawInputData( const BaseVector* vector, jit::HashAggrJitValueKind kind) { switch (kind) { + case jit::HashAggrJitValueKind::Bool: + return vector->valuesAsVoid(); case jit::HashAggrJitValueKind::Int8: return vector->asUnchecked>()->rawValues(); case jit::HashAggrJitValueKind::Int16: @@ -110,8 +137,6 @@ const void* hashAggrJitRawInputValues( return vector->asUnchecked>()->rawValues(); case jit::HashAggrJitValueKind::Double: return vector->asUnchecked>()->rawValues(); - case jit::HashAggrJitValueKind::Bool: - return nullptr; } return 
nullptr; } @@ -132,101 +157,74 @@ bool fillHashAggrJitRowInputRuntime( if (slot.desc.inputShape != jit::HashAggrJitRuntimeShape::Row) { return false; } - const bool isAvg = slot.desc.kind == jit::HashAggrJitKind::Avg; - const bool isDecimalSum = - slot.desc.kind == jit::HashAggrJitKind::Sum && slot.desc.decimal; - if (!isAvg && !isDecimalSum) { - return false; - } const auto* base = decoded.base(); if (base == nullptr || base->encoding() != VectorEncoding::Simple::ROW) { return false; } const auto* rowVector = base->asUnchecked(); - if (rowVector->childrenSize() < 2) { - return false; - } - const auto& sumVector = rowVector->childAt(0); - if (sumVector->encoding() != VectorEncoding::Simple::FLAT) { + if (rowVector->childrenSize() == 0) { return false; } - children.resize(2); - childPtrs.resize(2); - children[0] = jit::HashAggrJitScalarInputRuntime{ - .values = hashAggrJitRawInputValues(sumVector.get(), slot.desc.inputKind), - .indices = decoded.indices(), - .nulls = sumVector->rawNulls()}; - // field1 differs by aggregate: avg's count is a flat int64 scalar; decimal - // sum's isEmpty is a bit-packed bool whose rawValues() is the bit-word - // buffer consumed by the scalar bool bit-read fast path. - const auto& field1Vector = rowVector->childAt(1); - if (field1Vector->encoding() != VectorEncoding::Simple::FLAT) { - return false; - } - if (isAvg) { - children[1] = jit::HashAggrJitScalarInputRuntime{ - .values = hashAggrJitRawInputValues( - field1Vector.get(), jit::HashAggrJitValueKind::Int64), - .indices = decoded.indices(), - .nulls = field1Vector->rawNulls()}; - } else { - // isEmpty is bit-packed bool: valuesAsVoid() exposes the underlying - // bit-word buffer (rawValues() throws for bool). RowInputAdapterCodegen - // bit-reads it directly via the scalar child runtime. 
- children[1] = jit::HashAggrJitScalarInputRuntime{ - .values = field1Vector->valuesAsVoid(), + const auto numChildren = rowVector->childrenSize(); + children.resize(numChildren); + childPtrs.resize(numChildren); + for (auto field = 0; field < numChildren; ++field) { + const auto& childVector = rowVector->childAt(field); + if (childVector->encoding() != VectorEncoding::Simple::FLAT) { + return false; + } + const auto kind = hashAggrJitInputValueKind(childVector.get()); + if (!kind.has_value()) { + return false; + } + children[field] = jit::HashAggrJitScalarInputRuntime{ + .values = hashAggrJitRawInputData(childVector.get(), *kind), .indices = decoded.indices(), - .nulls = field1Vector->rawNulls()}; + .nulls = childVector->rawNulls()}; + childPtrs[field] = &children[field]; } - childPtrs[0] = &children[0]; - childPtrs[1] = &children[1]; input.row = jit::HashAggrJitRowInputRuntime{ .nulls = decoded.nulls(&rows), .children = childPtrs.data(), - .numChildren = static_cast(children.size())}; + .numChildren = static_cast(numChildren)}; return true; } -// Fills the raw flat sum/count field pointers for a partial avg ROW output. -// Returns false when the ROW children are not both FLAT (e.g. dictionary/ -// constant wrapped), in which case the JIT fast path is not applicable and the -// caller must fall back to the non-JIT extract path. +// Fills the raw child pointers for a ROW output runtime. Returns false when any +// ROW child is not FLAT (e.g. dictionary/constant wrapped), in which case the +// JIT fast path is not applicable and the caller must fall back to the non-JIT +// extract path. 
bool fillHashAggrJitRowOutputRuntime( jit::HashAggrJitOutputRuntime& output, std::vector& children, std::vector& childPtrs, - BaseVector* vector, - const jit::HashAggrJitSlot& slot) { + BaseVector* vector) { auto* rowVector = vector->asUnchecked(); - if (rowVector->childrenSize() < 2) { + if (rowVector->childrenSize() == 0) { return false; } - auto& sumVector = rowVector->childAt(0); - auto& countVector = rowVector->childAt(1); - if (sumVector->encoding() != VectorEncoding::Simple::FLAT || - countVector->encoding() != VectorEncoding::Simple::FLAT) { - return false; + const auto numChildren = rowVector->childrenSize(); + children.resize(numChildren); + childPtrs.resize(numChildren); + for (auto field = 0; field < numChildren; ++field) { + auto& childVector = rowVector->childAt(field); + if (childVector->encoding() != VectorEncoding::Simple::FLAT) { + return false; + } + const auto kind = hashAggrJitOutputValueKind(childVector.get()); + if (!kind.has_value()) { + return false; + } + children[field] = jit::HashAggrJitScalarOutputRuntime{ + .values = hashAggrJitRawOutputData(childVector.get(), *kind), + .nulls = childVector->mutableRawNulls(), + .vector = childVector.get()}; + childPtrs[field] = &children[field]; } - children.resize(2); - childPtrs.resize(2); - children[0] = jit::HashAggrJitScalarOutputRuntime{ - .values = slot.desc.decimal - ? 
static_cast( - sumVector->asUnchecked>()->mutableRawValues()) - : static_cast( - sumVector->asUnchecked>()->mutableRawValues()), - .nulls = sumVector->mutableRawNulls(), - .vector = sumVector.get()}; - children[1] = jit::HashAggrJitScalarOutputRuntime{ - .values = countVector->asUnchecked>()->mutableRawValues(), - .nulls = countVector->mutableRawNulls(), - .vector = countVector.get()}; - childPtrs[0] = &children[0]; - childPtrs[1] = &children[1]; output.row = jit::HashAggrJitRowOutputRuntime{ .nulls = vector->mutableRawNulls(), .children = childPtrs.data(), - .numChildren = static_cast(children.size()), + .numChildren = static_cast(numChildren), .vector = vector}; return true; } @@ -1305,7 +1303,7 @@ void GroupingSet::runHashAggrJitExtractChunks( if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { hashAggrJitOutputRuntimes_[slotIndex].scalar = jit::HashAggrJitScalarOutputRuntime{ - .values = hashAggrJitRawOutputValues( + .values = hashAggrJitRawOutputData( aggregateVector.get(), slot.desc.accumulatorKind), .nulls = aggregateVector->mutableRawNulls(), .vector = aggregateVector.get()}; @@ -1316,8 +1314,7 @@ void GroupingSet::runHashAggrJitExtractChunks( hashAggrJitOutputRuntimes_[slotIndex], hashAggrJitRowOutputChildren_[slotIndex], hashAggrJitRowOutputChildPtrs_[slotIndex], - aggregateVector.get(), - slot)) { + aggregateVector.get())) { canRunChunk = false; skipReason = "ROW output runtime requires flat scalar row children"; break; From 05c4b1739eaaffa3dfc64db465447547cf595e41 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 17:00:35 +0800 Subject: [PATCH 56/98] decouple JitDecimalSumState and JitDecimalAvgState --- .../jit/aggregation/HashAggrJitDecimalState.h | 33 +++++++++++++++++++ bolt/jit/aggregation/HashAggrJitTypes.h | 16 --------- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 1 + bolt/jit/aggregation/ops/DecimalSumOps.cpp | 1 + .../runtime/HashAggrDecimalRuntime.cpp | 2 +- 5 files changed, 36 insertions(+), 17 deletions(-) 
create mode 100644 bolt/jit/aggregation/HashAggrJitDecimalState.h diff --git a/bolt/jit/aggregation/HashAggrJitDecimalState.h b/bolt/jit/aggregation/HashAggrJitDecimalState.h new file mode 100644 index 000000000..98a16c3d7 --- /dev/null +++ b/bolt/jit/aggregation/HashAggrJitDecimalState.h @@ -0,0 +1,33 @@ +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include +#include + +#include "bolt/type/Type.h" + +namespace bytedance::bolt::jit { + +// JIT-internal decimal accumulator layouts shared only by decimal aggregate +// codegen and decimal extract runtime helpers. Keep them out of the framework +// planning/types header so non-decimal ops don't depend on aggregate-private +// row state details. +struct JitDecimalSumState { + bytedance::bolt::int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; +}; + +struct JitDecimalAvgState { + bytedance::bolt::int128_t sum{0}; + int64_t count{0}; + int64_t overflow{0}; +}; + +static_assert(std::is_standard_layout_v); +static_assert(std::is_standard_layout_v); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index b053bfcdb..b4d0fabed 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -2,7 +2,6 @@ #ifdef ENABLE_BOLT_JIT -#include #include #include #include @@ -18,21 +17,6 @@ namespace bytedance::bolt::jit { -// JIT-internal accumulator layouts for decimal sum/avg. Shared between the JIT -// codegen runtime helpers and the extract runtime helpers (which live in a -// different translation unit and need DecimalUtil). -struct JitDecimalSumState { - bytedance::bolt::int128_t sum{0}; - int64_t overflow{0}; - bool isEmpty{true}; -}; - -struct JitDecimalAvgState { - bytedance::bolt::int128_t sum{0}; - int64_t count{0}; - int64_t overflow{0}; -}; - // Runtime scalar input consumed by JIT add_dense functions. 'indices' maps the // add_dense row to the scalar value row. 
The owner decides the indexing // contract for 'nulls': top-level scalar inputs pass row-indexed nulls, while diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index ffa90fc36..274e67f07 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -8,6 +8,7 @@ #include #include "bolt/jit/aggregation/HashAggrJit.h" +#include "bolt/jit/aggregation/HashAggrJitDecimalState.h" namespace bytedance::bolt::jit { diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 9e6720517..38dff3963 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -8,6 +8,7 @@ #include #include "bolt/jit/aggregation/HashAggrJit.h" +#include "bolt/jit/aggregation/HashAggrJitDecimalState.h" namespace bytedance::bolt::jit { diff --git a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp index f18fff1d9..9d115d6f3 100644 --- a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp +++ b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp @@ -17,7 +17,7 @@ #include #include "bolt/functions/sparksql/DecimalUtil.h" -#include "bolt/jit/aggregation/HashAggrJitTypes.h" +#include "bolt/jit/aggregation/HashAggrJitDecimalState.h" #include "bolt/type/DecimalUtil.h" #include "bolt/vector/ComplexVector.h" #include "bolt/vector/FlatVector.h" From 62fbb70f741890c3196354fe34ceb2701c358265 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Fri, 12 Jun 2026 18:01:27 +0800 Subject: [PATCH 57/98] remove uncessary codes --- bolt/jit/aggregation/HashAggrJit.cpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 08b882fcd..67d43f6d7 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1049,22 +1049,6 @@ bool 
genExtractIR( auto* loop = llvm::BasicBlock::Create(context, "loop", func); auto* end = llvm::BasicBlock::Create(context, "end", func); builder.SetInsertPoint(entry); - for (auto i = 0; i < slots.size(); ++i) { - if (slots[i].desc.ops == nullptr || slots[i].desc.ops->canExtract == nullptr || - !slots[i].desc.ops->canExtract(slots[i], partialOutput)) { - continue; - } - auto* outputAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); - auto* outputRuntime = builder.CreateLoad(i8PtrTy, outputAddr); - std::unique_ptr output; - if (usesRowOutputRuntime(slots[i], partialOutput)) { - output = std::make_unique(codegen, outputRuntime); - } else { - output = - std::make_unique(codegen, outputRuntime); - } - output->resize(numGroups); - } builder.CreateCondBr(builder.CreateICmpSLE(numGroups, builder.getInt32(0)), end, loop); builder.SetInsertPoint(loop); From 0af777ab24adb1d828ed6af99f921974f269070f Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 10:52:29 +0800 Subject: [PATCH 58/98] fix benchmark crash caused by inconsistent accumulator kind and actual result vector --- bolt/exec/GroupingSet.cpp | 14 +++++- hashaggr_jit_refactor_plan.md | 92 +++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index bb21069a2..175d95432 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1301,10 +1301,22 @@ void GroupingSet::runHashAggrJitExtractChunks( // now points to HashAggrJitOutputRuntime rather than BaseVector directly. aggregateVector->resize(groups.size()); if (aggregateVector->encoding() == VectorEncoding::Simple::FLAT) { + // Derive the raw values kind from the output vector's actual storage + // type rather than slot.desc.accumulatorKind. They can differ: e.g. + // decimal avg's accumulatorKind is Int128 while its final result is a + // short decimal (FlatVector). 
Using accumulatorKind here would + // reinterpret an int64 buffer as int128 and corrupt the heap. + const auto outputKind = + hashAggrJitOutputValueKind(aggregateVector.get()); + if (!outputKind.has_value()) { + canRunChunk = false; + skipReason = "unsupported scalar output value kind"; + break; + } hashAggrJitOutputRuntimes_[slotIndex].scalar = jit::HashAggrJitScalarOutputRuntime{ .values = hashAggrJitRawOutputData( - aggregateVector.get(), slot.desc.accumulatorKind), + aggregateVector.get(), *outputKind), .nulls = aggregateVector->mutableRawNulls(), .vector = aggregateVector.get()}; } else if ( diff --git a/hashaggr_jit_refactor_plan.md b/hashaggr_jit_refactor_plan.md index 86a202a70..3f09a57e8 100644 --- a/hashaggr_jit_refactor_plan.md +++ b/hashaggr_jit_refactor_plan.md @@ -1364,3 +1364,95 @@ using HashAggrJitAddDenseFunc = - [ ] JIT add-dense ABI 传递的是 adapter-owned runtime payload; - [ ] row merge / row extract 不再出现 `rowField0/rowField1` 这类聚合专有字段名; - [ ] 新增一个 3-field intermediate 聚合时,不需要修改 InputAdapter/OutputAdapter 基类接口。 + +--- + +## 11. 事故复盘:`munmap_chunk(): invalid pointer`(commit `0722a59851` 引入) + +本章记录一次在 `HashAggrJitBenchmark` 上复现的堆破坏崩溃的完整定位过程与根因,作为 output runtime 绑定相关改动的回归警示。 + +### 11.1 现象 + +- 运行 `bolt_hashaggr_jit_benchmark`(RelWithDebInfo 行为,Release preset 构建)必崩。 +- 报错:`munmap_chunk(): invalid pointer`,`SIGABRT`。 +- 栈顶在算子关闭阶段析构中间结果 vector 时: + - `Driver::closeOperators()` → 释放 `RowVector` → 释放其 child `FlatVector`(int64) 的 values buffer → glibc `free` 检测到非法 chunk 指针。 +- hint:bug 出现在最近 5 个 commit 中。 + +### 11.2 定位过程 + +1. **缩小到具体 case**:在 benchmark `addCase()` 的 warmup 处加临时 `fprintf` 打印每个 case 名(每个 case warmup 时会先跑 nojit 再跑 jit)。运行后最后一条输出停在 `width4_merge_decimal_avg` 的 `jit` 阶段,**坐实崩溃 case = `width4_merge_decimal_avg`**。 +2. **bisect 到 commit**:先前已通过 `git reset --hard` 确认 first bad commit = `0722a59851`(其 parent `f752929ecc` 不崩)。 +3. 
**gdb 观察**: + - 崩溃发生在第二阶段(final aggregation)输出路径 `GroupingSet::runHashAggrJitExtractChunks`。 + - 在 decimal avg 的两个 helper 上下断:`jit_HashAggrExtractPartialDecimalAvg` 被调用 40000 次,但 `jit_HashAggrExtractFinalDecimalAvg` **一次都没进入**就崩了 → 说明堆已在 final 阶段“**extract 绑定阶段**”(`chunk.extract()` 之前)被破坏。 +4. **类型/精度推演**: + - `width4` 用 `DECIMAL(12,2)`(short decimal)。 + - decimal avg 中间 sum 类型按签名 `ROW(DECIMAL(38, a_scale), BIGINT)` → `DECIMAL(38,2)` 是 **long decimal(int128)**。 + - decimal avg final 结果类型 `r_precision=min(38,12+4)=16` → `DECIMAL(16,6)` 是 **short decimal**,存储为 **`FlatVector`**。 + - 但 descriptor 的 `accumulatorKind = Int128`(见 `AverageAggregate.cpp` 的 `DecimalAverageAggregate::createHashAggrJitDescriptor`)。 + +### 11.3 根因 + +`0722a59851` 把 `GroupingSet.cpp` 里的 `hashAggrJitRawOutputValues`(改名为 `hashAggrJitRawOutputData`)的 `Int128` 分支,从父 commit 的 `return nullptr` 改成了: + +```cpp +case jit::HashAggrJitValueKind::Int128: + return vector->asUnchecked>()->mutableRawValues(); +``` + +而 `runHashAggrJitExtractChunks` 的 **scalar final 输出绑定**(`GroupingSet.cpp:1306` 附近)用 `slot.desc.accumulatorKind` 来解释输出列: + +```cpp +.values = hashAggrJitRawOutputData(aggregateVector.get(), slot.desc.accumulatorKind) +``` + +对 decimal avg final:`accumulatorKind == Int128`,但 final 输出列真实类型是 short-decimal `FlatVector`。于是: + +1. 一个真实 `FlatVector` 被 `asUnchecked>()` 强转(类型混淆)。 +2. 调用 `mutableRawValues()`(见 `FlatVector.h:244`):此时 `values_` 是按 int64(8B/elem)分配且非 mutable,函数进入重分配分支: + - 按 `int128`(16B/elem)**重新分配 buffer**; + - `memcpy(newValues, rawValues_, byteSize(length))` 即按 2× 字节数从只有 8B/elem 的旧 buffer **越界读**; + - 把该 vector 的 `values_` / `rawValues_` 替换成 int128 尺寸 buffer。 +3. 
这步破坏堆(越界读踩坏相邻 chunk metadata,并把列状态搞乱),最终在算子析构释放该 `RowVector`/`FlatVector` 链时 glibc 报 `munmap_chunk(): invalid pointer`。 + +**为何 parent commit 不崩**:原 `Int128` 分支 `return nullptr`,从不触碰该列 buffer。decimal avg final 真正写入走 helper `jit_HashAggrExtractFinalDecimalAvg`,由 `longDecimal` flag 正确按 int64/int128 写回,**根本不需要这个预取的 raw values 指针**。 + +**关键定性**:crash 由 commit `0722a59851` 的这一行引入(`bolt/exec/GroupingSet.cpp` 内 `hashAggrJitRawOutputData` 的 `Int128` 分支),与 scalar-output 绑定处用 `accumulatorKind` 解释 short-decimal 输出列的错配共同作用。它本质是一个 **`accumulatorKind` ≠ 输出 vector 实际存储类型** 的类型混淆。 + +### 11.4 验证 + +把 `Int128` 分支临时改回 `return nullptr`(仅验证用,注释说明 Int128 scalar/decimal 输出走 helper 的 `vector()`,不读此 raw 指针),重编译运行: + +- `width4/8/16/32_merge_decimal_avg` 全部通过,crash 消失; +- 整个 benchmark 跑完无 `munmap` / `Aborted`。 + +→ 根因实锤。 + +### 11.5 修复(已实施:方案 1) + +采用 §11.5 的方案 1:**scalar output 绑定按输出 vector 真实类型推导 kind**,而非 `accumulatorKind`。 + +`runHashAggrJitExtractChunks` 的 FLAT scalar 输出绑定改为: + +```cpp +const auto outputKind = hashAggrJitOutputValueKind(aggregateVector.get()); +if (!outputKind.has_value()) { + canRunChunk = false; + skipReason = "unsupported scalar output value kind"; + break; +} +... .values = hashAggrJitRawOutputData(aggregateVector.get(), *outputKind) ... +``` + +`hashAggrJitOutputValueKind` 已存在,会按列真实类型(含 short/long decimal)推导 kind,从而保证 `hashAggrJitRawOutputData` 取到的指针宽度与列存储宽度一致,杜绝 int64↔int128 错配重分配。`hashAggrJitRawOutputData` 的 `Int128` 分支保持正常实现(用于真正的 long-decimal/HUGEINT 输出列)。 + +其余备选方向(方案 2/3)未采用,记录备查: + +1. 对走 helper 的 decimal/Int128 输出不预取 raw values(保持 nullptr); +2. 
统一约束指针宽度一致。 + +### 11.6 临时改动清理(已完成) + +- `bolt/exec/benchmarks/HashAggrJitBenchmark.cpp`:`addCase()` 内的 `fprintf` case 名打印 —— **已回退**。 +- `bolt/exec/GroupingSet.cpp`:`hashAggrJitRawOutputData` 的 `Int128` 分支临时 `return nullptr` —— **已恢复**为正常实现;正式修复落在 scalar 绑定处(见 §11.5)。 From 10d8047134d081c89d4c67ed9fb4224eaf8f6928 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 11:35:54 +0800 Subject: [PATCH 59/98] remove useless checkinputnulls in llvm ir codegen --- bolt/jit/aggregation/HashAggrJit.cpp | 11 +++-------- bolt/jit/aggregation/HashAggrJit.h | 4 +++- bolt/jit/aggregation/ops/AvgOps.cpp | 2 -- bolt/jit/aggregation/ops/CountOps.cpp | 2 -- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 2 -- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 2 -- bolt/jit/aggregation/ops/MinMaxOps.cpp | 1 - bolt/jit/aggregation/ops/SumOps.cpp | 1 - 8 files changed, 6 insertions(+), 19 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 67d43f6d7..bb52fe639 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -19,8 +19,6 @@ extern "C" { -using bytedance::bolt::jit::HashAggrJitInputRuntime; -using bytedance::bolt::jit::HashAggrJitOutputRuntime; using bytedance::bolt::jit::HashAggrJitRowInputRuntime; using bytedance::bolt::jit::HashAggrJitRowOutputRuntime; using bytedance::bolt::jit::HashAggrJitScalarInputRuntime; @@ -1007,7 +1005,7 @@ bool genAddDenseIR( if (addFn == nullptr) { return false; } - addFn(codegen, group, *input, row, slot, checkInputNulls, nextBlock); + addFn(codegen, group, *input, row, slot, nextBlock); builder.CreateBr(nextBlock); builder.SetInsertPoint(nextBlock); } @@ -1244,11 +1242,8 @@ bool HashAggrJitChunk::codegen() { addDenseNoNull_ = reinterpret_cast( module_->getFuncPtr(addNoNullFn)); extract_ = reinterpret_cast(module_->getFuncPtr(extractFn)); - if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr || - extract_ == nullptr) { - return 
false; - } - return true; + return init_ != nullptr && addDense_ != nullptr && + addDenseNoNull_ != nullptr && extract_ != nullptr; } } // namespace bytedance::bolt::jit diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index c62108817..eb1372741 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -32,7 +32,6 @@ struct HashAggrJitOps { const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot&, - bool checkInputNulls, llvm::BasicBlock* nextBlock); using CanExtractFn = bool (*)(const HashAggrJitSlot&, bool partialOutput); using ExtractFn = void (*)( @@ -50,8 +49,11 @@ struct HashAggrJitOps { }; struct HashAggrJitExtractTarget { + // Codegen adapter for the destination output vector to write results into. const OutputAdapterCodegen& output; + // The target row index (runtime llvm::Value) to write the extracted result. llvm::Value* row; + // Whether to emit partial (intermediate) results instead of final ones. 
bool partialOutput; }; diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index bca74a945..fd31df76f 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -45,7 +45,6 @@ void compileAvgAddRawInput( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { auto* inputRow = input.read(row, slot.desc.inputKind); auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); @@ -76,7 +75,6 @@ void compileAvgAddIntermediateResults( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { codegen.clearAccumulatorNull(group, slot); auto* sumRow = input.readRowField(row, 0, HashAggrJitValueKind::Double); diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp index cdd8067d5..330e1d27b 100644 --- a/bolt/jit/aggregation/ops/CountOps.cpp +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -42,7 +42,6 @@ void compileCountAddRawInput( const InputAdapterCodegen& /*input*/, llvm::Value* /*row*/, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { addInc(codegen, group, slot, codegen.builder().getInt64(1)); } @@ -53,7 +52,6 @@ void compileCountAddIntermediateResults( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { llvm::Value* inc = nullptr; if (slot.desc.countStar) { diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index 274e67f07..af5fc23ac 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -46,7 +46,6 @@ void compileDecimalAvgAddRawInput( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { auto& b = codegen.builder(); auto* inputRow = input.read(row, slot.desc.inputKind); @@ -72,7 +71,6 @@ void compileDecimalAvgAddIntermediateResults( 
const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock* nextBlock) { auto& b = codegen.builder(); auto* function = b.GetInsertBlock()->getParent(); diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 38dff3963..cf26f70cf 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -46,7 +46,6 @@ void compileDecimalSumAddRawInput( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { auto& b = codegen.builder(); auto* inputRow = input.read(row, slot.desc.inputKind); @@ -66,7 +65,6 @@ void compileDecimalSumAddIntermediateResults( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock* nextBlock) { auto& b = codegen.builder(); auto* function = b.GetInsertBlock()->getParent(); diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp index ed101d108..9f2d645f6 100644 --- a/bolt/jit/aggregation/ops/MinMaxOps.cpp +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -32,7 +32,6 @@ void compileMinMaxUpdate( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { auto* inputRow = input.read(row, slot.desc.inputKind); auto* value = codegen.castValue( diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp index e2815db4d..a32717d4e 100644 --- a/bolt/jit/aggregation/ops/SumOps.cpp +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -34,7 +34,6 @@ void compileSumAccumulate( const InputAdapterCodegen& input, llvm::Value* row, const HashAggrJitSlot& slot, - bool, llvm::BasicBlock*) { auto* inputRow = input.read(row, slot.desc.inputKind); auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); From b327d62a5818266542092855b04d0abded6e7f9b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 
12:12:56 +0800 Subject: [PATCH 60/98] remove useless code --- bolt/jit/aggregation/HashAggrJit.cpp | 2 ++ bolt/jit/aggregation/HashAggrJitTypes.h | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index bb52fe639..34226ce42 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1175,6 +1175,7 @@ bool isHashAggrJitSupportedType(TypeKind kind) { } } +/* std::string HashAggrJitDescriptor::signature() const { return fmt::format( "{}_{}_{}_{}_{}_{}_{}_{}", @@ -1187,6 +1188,7 @@ std::string HashAggrJitDescriptor::signature() const { hashAggrJitRuntimeShapeName(inputShape), hashAggrJitRuntimeShapeName(outputShape)); } +*/ bool HashAggrJitChunk::canExtract() const { if (extract_ == nullptr) { diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index b4d0fabed..201c02a4c 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -138,7 +138,7 @@ struct HashAggrJitDescriptor { int32_t auxScale{0}; const HashAggrJitOps* ops{nullptr}; - std::string signature() const; + // std::string signature() const; }; struct HashAggrJitSlot { From b9c3aee116adb0f722fdfa13f3c79ca49874ffc1 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 12:56:52 +0800 Subject: [PATCH 61/98] remove useless get null of row field --- bolt/jit/aggregation/HashAggrJit.cpp | 19 +++++++ bolt/jit/aggregation/HashAggrJit.h | 13 +++++ bolt/jit/aggregation/ops/AvgOps.cpp | 6 +-- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 4 +- hashaggr_jit_refactor_plan.md | 60 ++++++++++++++++++++++ 5 files changed, 96 insertions(+), 6 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 34226ce42..7a703196e 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -583,6 +583,13 @@ llvm::Value* 
ScalarInputAdapterCodegen::readRowField( BOLT_UNSUPPORTED("ScalarInputAdapterCodegen does not support ROW field load"); } +llvm::Value* ScalarInputAdapterCodegen::readRowFieldValue( + llvm::Value*, + int32_t, + HashAggrJitValueKind) const { + BOLT_UNSUPPORTED("ScalarInputAdapterCodegen does not support ROW field load"); +} + RowInputAdapterCodegen::RowInputAdapterCodegen( HashAggrJitCodegen& codegen, llvm::Value* input) @@ -621,6 +628,18 @@ llvm::Value* RowInputAdapterCodegen::readRowField( return IRRow::pack(codegen_.builder(), value, isRowFieldNull(row, field)); } +llvm::Value* RowInputAdapterCodegen::readRowFieldValue( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const { + // Skips the per-field null check CFG: returns only the raw field value. + // Valid when the field is guaranteed non-null on this path (its null bit is + // not consumed by the aggregate's merge semantics). + auto* child = loadChild(field); + return ::bytedance::bolt::jit::loadScalarInputValue( + codegen_.builder(), child, row, kind); +} + llvm::Value* RowInputAdapterCodegen::isRowFieldNull( llvm::Value* row, int32_t field) const { diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index eb1372741..e25d25f9a 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -116,6 +116,13 @@ class InputAdapterCodegen { llvm::Value* row, int32_t field, HashAggrJitValueKind kind) const = 0; + // Reads only the raw value of a ROW child, skipping the per-field null check + // CFG. Use when the framework guarantees the field is non-null on this path + // (i.e. the field's null bit is not consumed by the aggregate semantics). 
+ virtual llvm::Value* readRowFieldValue( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const = 0; }; class ScalarInputAdapterCodegen final : public InputAdapterCodegen { @@ -128,6 +135,8 @@ class ScalarInputAdapterCodegen final : public InputAdapterCodegen { llvm::Value* isNull(llvm::Value* row) const override; llvm::Value* readRowField(llvm::Value*, int32_t, HashAggrJitValueKind) const override; + llvm::Value* readRowFieldValue(llvm::Value*, int32_t, HashAggrJitValueKind) + const override; private: HashAggrJitCodegen& codegen_; @@ -146,6 +155,10 @@ class RowInputAdapterCodegen final : public InputAdapterCodegen { llvm::Value* row, int32_t field, HashAggrJitValueKind kind) const override; + llvm::Value* readRowFieldValue( + llvm::Value* row, + int32_t field, + HashAggrJitValueKind kind) const override; private: llvm::Value* loadChild(int32_t field) const; diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index fd31df76f..fedd904dc 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -77,10 +77,8 @@ void compileAvgAddIntermediateResults( const HashAggrJitSlot& slot, llvm::BasicBlock*) { codegen.clearAccumulatorNull(group, slot); - auto* sumRow = input.readRowField(row, 0, HashAggrJitValueKind::Double); - auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); - auto* sum = IRRow::getValue(codegen.builder(), sumRow); - auto* count = IRRow::getValue(codegen.builder(), countRow); + auto* sum = input.readRowFieldValue(row, 0, HashAggrJitValueKind::Double); + auto* count = input.readRowFieldValue(row, 1, HashAggrJitValueKind::Int64); auto* oldSum = codegen.loadValue(group, codegen.builder().getDoubleTy(), slot.offset); codegen.storeValue( diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index cf26f70cf..eeaf18575 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ 
b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -81,9 +81,9 @@ void compileDecimalSumAddIntermediateResults( auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "sum_decimal_merge", function, continueBlock); auto* sumRow = input.readRowField(row, 0, slot.desc.inputKind); - auto* isEmptyRow = input.readRowField(row, 1, HashAggrJitValueKind::Bool); + auto* incomingIsEmpty = + input.readRowFieldValue(row, 1, HashAggrJitValueKind::Bool); auto* sumIsNull = IRRow::getIsNull(b, sumRow); - auto* incomingIsEmpty = IRRow::getValue(b, isEmptyRow); auto* isNotEmpty = b.CreateICmpEQ(incomingIsEmpty, b.getInt8(0)); auto* isOverflow = b.CreateAnd(sumIsNull, isNotEmpty); b.CreateCondBr(isOverflow, overflowBlock, mergeBlock); diff --git a/hashaggr_jit_refactor_plan.md b/hashaggr_jit_refactor_plan.md index 3f09a57e8..bff2d87b7 100644 --- a/hashaggr_jit_refactor_plan.md +++ b/hashaggr_jit_refactor_plan.md @@ -1456,3 +1456,63 @@ if (!outputKind.has_value()) { - `bolt/exec/benchmarks/HashAggrJitBenchmark.cpp`:`addCase()` 内的 `fprintf` case 名打印 —— **已回退**。 - `bolt/exec/GroupingSet.cpp`:`hashAggrJitRawOutputData` 的 `Int128` 分支临时 `return nullptr` —— **已恢复**为正常实现;正式修复落在 scalar 绑定处(见 §11.5)。 + +--- + +## 12. 
优化:ROW merge 输入跳过 per-field null 检查(`readRowFieldValue`) + +### 12.1 背景 + +`RowInputAdapterCodegen::readRowField` 对每个 ROW child 都会生成一段 per-field null 检查 CFG(`row_field_null_check` / `row_field_null_done` + PHI)来产出该 field 的 `is_null`,再 `IRRow::pack(value, is_null)`。 + +但 `addIntermediateResults`(merge)路径上,框架外层 `genAddDenseIR` 已对 **top-level ROW null** 统一发射过 null guard;ROW 内部各 field 是否需要 null 位,取决于具体聚合语义: + +- **avg merge**(`ROW(double sum, bigint count)`):业务上不读 field 的 null,只用 value; +- **decimal sum merge**(`ROW(decimal sum, bool isEmpty)`):仅 **sum** 字段的 null 被用来编码 overflow(JIT partial extract 溢出时 `sumVector->setNull`),`isEmpty` 字段的 null 不被消费; +- **decimal avg merge**(`ROW(decimal sum, bigint count)`):sum/count 的 null 都参与 overflow 判定,**不能跳过**。 + +因此对“null 位未被业务消费”的 field,生成 null 检查 CFG 是纯浪费。由于 `nulls` 指针编译期未知,这段 CFG 在 decimal 路径上 LLVM 往往**折不掉**,既增加 IR 体积也增加实际指令。 + +### 12.2 改动 + +新增 value-only 接口 `InputAdapterCodegen::readRowFieldValue(row, field, kind)`: + +- 语义:只返回 ROW child 的裸值(`llvm::Value*`),**跳过** per-field null 检查 CFG; +- 适用前提:该 field 在当前路径上保证非空(其 null 位不被聚合语义消费); +- `ScalarInputAdapterCodegen`:`BOLT_UNSUPPORTED`(与 `readRowField` 一致); +- `RowInputAdapterCodegen`:直接 `loadChild(field)` → `loadScalarInputValue(...)`,不调 `isRowFieldNull`。 + +调用点改造(严格按 null 是否被消费区分): + +| 调用点 | 处理 | +|--------|------| +| `compileAvgAddIntermediateResults`(sum/count) | 两字段全换 `readRowFieldValue` | +| `compileDecimalSumAddIntermediateResults` — `isEmpty` | 换 `readRowFieldValue` | +| `compileDecimalSumAddIntermediateResults` — `sum` | **保留** `readRowField`(`sumIsNull` 编码 overflow) | +| `compileDecimalAvgAddIntermediateResults` — sum/count | **保留** `readRowField`(两个 null 都参与 overflow 判定) | + +### 12.3 正确性依据 + +JIT 的 partial extract(`HashAggrDecimalRuntime.cpp` 的 `jit_HashAggrExtractPartialDecimalSum/Avg`)在 sum 溢出时会 `sumVector->setNull(row, true)`(整行 ROW 非 null)。下游 merge 正是靠读该 field 的 null 来识别并传播 overflow,故 decimal 的 sum 字段(及 decimal avg 的 count)**必须**保留 `readRowField`。注意:非 JIT 的 
`extractAccumulators` 不置 sum-null,但 JIT pipeline 中上游可能是 JIT partial extract,跨阶段契约要求 merge 端兼容 sum=null。 + +### 12.4 性能验证(Release,jit 路径,单位 ms,越小越好) + +| case | baseline | optimized | 变化 | +|------|----------|-----------|------| +| width8_merge_avg | 6.16 | 6.00 | -2.6% | +| width16_merge_avg | 11.08 | 10.68 | -3.6% | +| width32_merge_avg | 21.47 | 20.51 | -4.5% | +| width4_merge_decimal_sum | 8.81 | 8.06 | -8.5% | +| width8_merge_decimal_sum | 15.90 | 15.48 | -2.6% | +| width16_merge_decimal_sum | 30.21 | 29.74 | -1.6% | +| width32_merge_decimal_sum | 61.05 | 60.14 | -1.5% | + +- avg merge:宽度越大收益越明显(每多一列多省一段 null CFG,线性放大); +- decimal sum merge:稳定小幅提升; +- `sum`(标量输入)不受影响;功能无回归,benchmark 无 crash。 + +### 12.5 涉及文件 + +- `bolt/jit/aggregation/HashAggrJit.h`:`InputAdapterCodegen` 新增 `readRowFieldValue` 纯虚 + 两子类声明; +- `bolt/jit/aggregation/HashAggrJit.cpp`:两子类实现; +- `bolt/jit/aggregation/ops/AvgOps.cpp`、`ops/DecimalSumOps.cpp`:按上表切换调用。 From d537d393640e476a4ae380a54c9e4950629e63e6 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 14:20:50 +0800 Subject: [PATCH 62/98] separate decimal_sum decimal_avg operations into xxxops --- bolt/jit/aggregation/ops/DecimalOps.h | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 bolt/jit/aggregation/ops/DecimalOps.h diff --git a/bolt/jit/aggregation/ops/DecimalOps.h b/bolt/jit/aggregation/ops/DecimalOps.h new file mode 100644 index 000000000..423a5e289 --- /dev/null +++ b/bolt/jit/aggregation/ops/DecimalOps.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef ENABLE_BOLT_JIT + +#include "bolt/jit/aggregation/HashAggrJit.h" + +// Decimal-specific JIT codegen helper shared across decimal ops translation +// units. 
It lives with the decimal ops (rather than on the framework +// HashAggrJitCodegen) so decimal knowledge stays out of the generic framework, +// and is declared here because it is defined in DecimalSumOps.cpp but also used +// by DecimalAvgOps.cpp. Extract helpers that are used within a single TU stay +// file-local in their respective ops files. +namespace bytedance::bolt::jit { + +// Inline i128 accumulate-with-overflow used by decimal sum/avg add+merge. +// Loads the i128 sum at 'group + sumOffset' and the i64 overflow counter at +// 'group + overflowOffset', computes sum += addend, updates the overflow +// counter by the carry direction, and stores both back. +void emitDecimalAddWithOverflow( + HashAggrJitCodegen& codegen, + llvm::Value* group, + int32_t sumOffset, + int32_t overflowOffset, + llvm::Value* addend); + +} // namespace bytedance::bolt::jit + +#endif // ENABLE_BOLT_JIT From fca0131f204f05960d042a6e94c172fe9e0e92cb Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 14:28:17 +0800 Subject: [PATCH 63/98] remove unnecessary namespace prefix --- bolt/jit/aggregation/HashAggrJit.cpp | 142 ++++-------------- bolt/jit/aggregation/HashAggrJit.h | 26 ---- .../jit/aggregation/HashAggrJitDecimalState.h | 4 +- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 46 +++++- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 84 ++++++++++- 5 files changed, 153 insertions(+), 149 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 7a703196e..1d7206f35 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -493,38 +493,41 @@ HashAggrJitCodegen::HashAggrJitCodegen(llvm::Module& module) : module_(module) { } llvm::Type* HashAggrJitCodegen::llvmType(HashAggrJitValueKind kind) const { - return ::bytedance::bolt::jit::llvmType(builder(), kind); + // Qualified call: the unqualified name would resolve to this member (name + // hiding stops lookup at class scope), so qualify to 
reach the file-local + // free function. + return bytedance::bolt::jit::llvmType(builder(), kind); } llvm::Value* HashAggrJitCodegen::isInputNull( llvm::Value* nulls, llvm::Value* row) const { - return ::bytedance::bolt::jit::isInputNull(builder(), nulls, row); + return bytedance::bolt::jit::isInputNull(builder(), nulls, row); } llvm::Value* HashAggrJitCodegen::isAccumulatorNull( llvm::Value* group, const HashAggrJitSlot& slot) const { - return ::bytedance::bolt::jit::isAccumulatorNull(builder(), group, slot); + return bytedance::bolt::jit::isAccumulatorNull(builder(), group, slot); } void HashAggrJitCodegen::clearAccumulatorNull( llvm::Value* group, const HashAggrJitSlot& slot) const { - ::bytedance::bolt::jit::clearAccumulatorNull(builder(), group, slot); + bytedance::bolt::jit::clearAccumulatorNull(builder(), group, slot); } void HashAggrJitCodegen::setAccumulatorNull( llvm::Value* group, const HashAggrJitSlot& slot) const { - ::bytedance::bolt::jit::setAccumulatorNull(builder(), group, slot); + bytedance::bolt::jit::setAccumulatorNull(builder(), group, slot); } llvm::LoadInst* HashAggrJitCodegen::loadValue( llvm::Value* row, llvm::Type* type, int32_t offset) const { - return ::bytedance::bolt::jit::loadValue(builder(), row, type, offset); + return bytedance::bolt::jit::loadValue(builder(), row, type, offset); } void HashAggrJitCodegen::storeValue( @@ -532,18 +535,18 @@ void HashAggrJitCodegen::storeValue( llvm::Type* type, int32_t offset, llvm::Value* value) const { - ::bytedance::bolt::jit::storeValue(builder(), row, type, offset, value); + bytedance::bolt::jit::storeValue(builder(), row, type, offset, value); } llvm::Value* HashAggrJitCodegen::castValue( llvm::Value* value, HashAggrJitValueKind from, HashAggrJitValueKind to) const { - return ::bytedance::bolt::jit::castValue(builder(), value, from, to); + return bytedance::bolt::jit::castValue(builder(), value, from, to); } bool HashAggrJitCodegen::isFloatKind(HashAggrJitValueKind kind) const { - return 
::bytedance::bolt::jit::isFloatKind(kind); + return bytedance::bolt::jit::isFloatKind(kind); } ScalarInputAdapterCodegen::ScalarInputAdapterCodegen( @@ -559,7 +562,7 @@ llvm::StructType* ScalarInputAdapterCodegen::irRowType( llvm::Value* ScalarInputAdapterCodegen::read( llvm::Value* row, HashAggrJitValueKind kind) const { - auto* value = ::bytedance::bolt::jit::loadScalarInputValue( + auto* value = loadScalarInputValue( codegen_.builder(), input_, row, kind); // add_dense emits the top-level null guard before invoking aggregate ops. // Therefore rows reaching ops are non-null; keep the IRRow contract explicit @@ -568,7 +571,7 @@ llvm::Value* ScalarInputAdapterCodegen::read( } llvm::Value* ScalarInputAdapterCodegen::loadNulls() const { - return ::bytedance::bolt::jit::loadScalarInputNulls( + return loadScalarInputNulls( codegen_.builder(), input_); } @@ -596,7 +599,7 @@ RowInputAdapterCodegen::RowInputAdapterCodegen( : codegen_(codegen), input_(input) {} llvm::Value* RowInputAdapterCodegen::loadChild(int32_t field) const { - return ::bytedance::bolt::jit::loadRowInputChild( + return loadRowInputChild( codegen_.builder(), input_, field); } @@ -611,7 +614,7 @@ llvm::Value* RowInputAdapterCodegen::read(llvm::Value*, HashAggrJitValueKind) } llvm::Value* RowInputAdapterCodegen::loadNulls() const { - return ::bytedance::bolt::jit::loadRowInputNulls(codegen_.builder(), input_); + return loadRowInputNulls(codegen_.builder(), input_); } llvm::Value* RowInputAdapterCodegen::isNull(llvm::Value* row) const { @@ -623,7 +626,7 @@ llvm::Value* RowInputAdapterCodegen::readRowField( int32_t field, HashAggrJitValueKind kind) const { auto* child = loadChild(field); - auto* value = ::bytedance::bolt::jit::loadScalarInputValue( + auto* value = loadScalarInputValue( codegen_.builder(), child, row, kind); return IRRow::pack(codegen_.builder(), value, isRowFieldNull(row, field)); } @@ -636,7 +639,7 @@ llvm::Value* RowInputAdapterCodegen::readRowFieldValue( // Valid when the field is 
guaranteed non-null on this path (its null bit is // not consumed by the aggregate's merge semantics). auto* child = loadChild(field); - return ::bytedance::bolt::jit::loadScalarInputValue( + return loadScalarInputValue( codegen_.builder(), child, row, kind); } @@ -645,7 +648,7 @@ llvm::Value* RowInputAdapterCodegen::isRowFieldNull( int32_t field) const { auto* child = loadChild(field); auto* nulls = - ::bytedance::bolt::jit::loadScalarInputNulls(codegen_.builder(), child); + loadScalarInputNulls(codegen_.builder(), child); auto* hasNulls = codegen_.builder().CreateICmpNE( nulls, llvm::ConstantPointerNull::get( @@ -659,7 +662,7 @@ llvm::Value* RowInputAdapterCodegen::isRowFieldNull( auto* noNullsEnd = codegen_.builder().GetInsertBlock(); codegen_.builder().SetInsertPoint(nullCheckBlock); - auto* index = ::bytedance::bolt::jit::loadScalarInputIndex( + auto* index = loadScalarInputIndex( codegen_.builder(), child, row); auto* isNull = codegen_.isInputNull(nulls, index); codegen_.builder().CreateBr(doneBlock); @@ -679,7 +682,7 @@ ScalarOutputAdapterCodegen::ScalarOutputAdapterCodegen( : codegen_(codegen), output_(output) {} llvm::Value* ScalarOutputAdapterCodegen::vector() const { - return ::bytedance::bolt::jit::loadScalarOutputVector( + return loadScalarOutputVector( codegen_.builder(), output_); } @@ -697,7 +700,7 @@ void ScalarOutputAdapterCodegen::write( auto* isNull = IRRow::getIsNull(codegen_.builder(), irRow); if (supportsRawFlatOutput(kind)) { auto* type = codegen_.llvmType(kind); - auto* values = ::bytedance::bolt::jit::loadScalarOutputValues( + auto* values = loadScalarOutputValues( codegen_.builder(), output_); auto* typedValues = codegen_.builder().CreatePointerCast( values, type->getPointerTo()); @@ -707,9 +710,9 @@ void ScalarOutputAdapterCodegen::write( codegen_.builder().CreateZExt(row, codegen_.builder().getInt64Ty())); auto* store = codegen_.builder().CreateStore(value, valueAddr); store->setAlignment(llvm::Align(1)); - auto* nulls = 
::bytedance::bolt::jit::loadScalarOutputNulls( + auto* nulls = loadScalarOutputNulls( codegen_.builder(), output_); - ::bytedance::bolt::jit::emitOutputNullBit( + emitOutputNullBit( codegen_.builder(), nulls, row, isNull); return; } @@ -726,9 +729,9 @@ void ScalarOutputAdapterCodegen::writeField( void ScalarOutputAdapterCodegen::writeNull( llvm::Value* row, llvm::Value* isNull) const { - auto* nulls = ::bytedance::bolt::jit::loadScalarOutputNulls( + auto* nulls = loadScalarOutputNulls( codegen_.builder(), output_); - ::bytedance::bolt::jit::emitOutputNullBit( + emitOutputNullBit( codegen_.builder(), nulls, row, isNull); } @@ -738,12 +741,12 @@ RowOutputAdapterCodegen::RowOutputAdapterCodegen( : codegen_(codegen), output_(output) {} llvm::Value* RowOutputAdapterCodegen::loadChild(int32_t field) const { - return ::bytedance::bolt::jit::loadRowOutputChild( + return loadRowOutputChild( codegen_.builder(), output_, field); } llvm::Value* RowOutputAdapterCodegen::vector() const { - return ::bytedance::bolt::jit::loadRowOutputVector(codegen_.builder(), output_); + return loadRowOutputVector(codegen_.builder(), output_); } void RowOutputAdapterCodegen::resize(llvm::Value* size) const { @@ -771,7 +774,7 @@ void RowOutputAdapterCodegen::writeField( auto* value = IRRow::getValue(codegen_.builder(), irRow); auto* isNull = IRRow::getIsNull(codegen_.builder(), irRow); auto* type = codegen_.llvmType(kind); - auto* values = ::bytedance::bolt::jit::loadScalarOutputValues( + auto* values = loadScalarOutputValues( codegen_.builder(), child); auto* typedValues = codegen_.builder().CreatePointerCast(values, type->getPointerTo()); @@ -779,100 +782,21 @@ void RowOutputAdapterCodegen::writeField( auto* valueAddr = codegen_.builder().CreateInBoundsGEP(type, typedValues, row64); auto* store = codegen_.builder().CreateStore(value, valueAddr); store->setAlignment(llvm::Align(1)); - auto* nulls = ::bytedance::bolt::jit::loadScalarOutputNulls( + auto* nulls = loadScalarOutputNulls( 
codegen_.builder(), child); - ::bytedance::bolt::jit::emitOutputNullBit( + emitOutputNullBit( codegen_.builder(), nulls, row, isNull); } void RowOutputAdapterCodegen::writeNull( llvm::Value* row, llvm::Value* isNull) const { - auto* nulls = ::bytedance::bolt::jit::loadRowOutputNulls( + auto* nulls = loadRowOutputNulls( codegen_.builder(), output_); - ::bytedance::bolt::jit::emitOutputNullBit( + emitOutputNullBit( codegen_.builder(), nulls, row, isNull); } -void HashAggrJitCodegen::emitDecimalSumExtract( - llvm::Value* output, - llvm::Value* row, - llvm::Value* group, - const HashAggrJitSlot& slot, - bool partialOutput) const { - const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalSum" - : "jit_HashAggrExtractFinalDecimalSum"; - auto* longDecimal = builder().getInt8( - slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); - builder().CreateCall( - module_.getFunction(fn), - {output, - row, - group, - builder().getInt32(slot.offset), - builder().getInt32(slot.desc.precision), - builder().getInt32(slot.desc.scale), - longDecimal}); -} - -void HashAggrJitCodegen::emitDecimalAvgExtract( - llvm::Value* output, - llvm::Value* row, - llvm::Value* group, - const HashAggrJitSlot& slot, - bool partialOutput) const { - const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalAvg" - : "jit_HashAggrExtractFinalDecimalAvg"; - auto* longDecimal = builder().getInt8( - slot.desc.auxPrecision > bytedance::bolt::ShortDecimalType::kMaxPrecision - ? 
1 - : 0); - builder().CreateCall( - module_.getFunction(fn), - {output, - row, - group, - builder().getInt32(slot.offset), - builder().getInt32(slot.desc.precision), - builder().getInt32(slot.desc.scale), - builder().getInt32(slot.desc.auxPrecision), - builder().getInt32(slot.desc.auxScale), - longDecimal}); -} - -void HashAggrJitCodegen::emitDecimalAddWithOverflow( - llvm::Value* group, - int32_t sumOffset, - int32_t overflowOffset, - llvm::Value* addend) const { - auto& b = builder(); - auto* i128Ty = b.getInt128Ty(); - auto* i64Ty = b.getInt64Ty(); - auto* zero128 = llvm::ConstantInt::get(i128Ty, 0); - - auto* oldSum = loadValue(group, i128Ty, sumOffset); - auto* newSum = b.CreateAdd(oldSum, addend); - storeValue(group, i128Ty, sumOffset, newSum); - - // Mirror jitHashAggrAddWithOverflow: - // +1 if a>0 && b>0 && result<0 (positive overflow) - // -1 if a<0 && b<0 && result>=0 (negative overflow) - auto* aPos = b.CreateICmpSGT(oldSum, zero128); - auto* bPos = b.CreateICmpSGT(addend, zero128); - auto* rNeg = b.CreateICmpSLT(newSum, zero128); - auto* posOverflow = b.CreateAnd(b.CreateAnd(aPos, bPos), rNeg); - - auto* aNeg = b.CreateICmpSLT(oldSum, zero128); - auto* bNeg = b.CreateICmpSLT(addend, zero128); - auto* rNonNeg = b.CreateICmpSGE(newSum, zero128); - auto* negOverflow = b.CreateAnd(b.CreateAnd(aNeg, bNeg), rNonNeg); - - auto* carry = b.CreateSub( - b.CreateZExt(posOverflow, i64Ty), b.CreateZExt(negOverflow, i64Ty)); - auto* oldOverflow = loadValue(group, i64Ty, overflowOffset); - storeValue(group, i64Ty, overflowOffset, b.CreateAdd(oldOverflow, carry)); -} - namespace { char hashAggrJitRuntimeShapeName(HashAggrJitRuntimeShape shape) { diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index e25d25f9a..529492543 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -263,32 +263,6 @@ class HashAggrJitCodegen { HashAggrJitValueKind from, HashAggrJitValueKind to) const; bool 
isFloatKind(HashAggrJitValueKind kind) const; - // Decimal extract: calls a runtime helper that reads the JIT decimal - // accumulator from 'group + slot.offset', applies overflow/precision checks - // and writes the result (final flat decimal / partial row) into 'vector'. - void emitDecimalSumExtract( - llvm::Value* vector, - llvm::Value* row, - llvm::Value* group, - const HashAggrJitSlot& slot, - bool partialOutput) const; - void emitDecimalAvgExtract( - llvm::Value* vector, - llvm::Value* row, - llvm::Value* group, - const HashAggrJitSlot& slot, - bool partialOutput) const; - - // Inline i128 accumulate-with-overflow used by decimal sum/avg add+merge. - // Loads the i128 sum at 'group + sumOffset' and the i64 overflow counter at - // 'group + overflowOffset', computes sum += addend, updates the overflow - // counter by the carry direction (mirrors jitHashAggrAddWithOverflow), and - // stores both back. Replaces the per-row runtime helper call with pure IR. - void emitDecimalAddWithOverflow( - llvm::Value* group, - int32_t sumOffset, - int32_t overflowOffset, - llvm::Value* addend) const; private: llvm::Module& module_; diff --git a/bolt/jit/aggregation/HashAggrJitDecimalState.h b/bolt/jit/aggregation/HashAggrJitDecimalState.h index 98a16c3d7..85271bf1e 100644 --- a/bolt/jit/aggregation/HashAggrJitDecimalState.h +++ b/bolt/jit/aggregation/HashAggrJitDecimalState.h @@ -14,13 +14,13 @@ namespace bytedance::bolt::jit { // planning/types header so non-decimal ops don't depend on aggregate-private // row state details. 
struct JitDecimalSumState { - bytedance::bolt::int128_t sum{0}; + int128_t sum{0}; int64_t overflow{0}; bool isEmpty{true}; }; struct JitDecimalAvgState { - bytedance::bolt::int128_t sum{0}; + int128_t sum{0}; int64_t count{0}; int64_t overflow{0}; }; diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index af5fc23ac..cbaf64c40 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -9,6 +9,8 @@ #include "bolt/jit/aggregation/HashAggrJit.h" #include "bolt/jit/aggregation/HashAggrJitDecimalState.h" +#include "bolt/jit/aggregation/ops/DecimalOps.h" +#include "bolt/type/Type.h" namespace bytedance::bolt::jit { @@ -53,8 +55,8 @@ void compileDecimalAvgAddRawInput( auto* value = codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - codegen.emitDecimalAddWithOverflow( - group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + emitDecimalAddWithOverflow( + codegen, group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); // ++count. auto* oldCount = codegen.loadValue(group, b.getInt64Ty(), slot.offset + kCountOffset); @@ -105,8 +107,8 @@ void compileDecimalAvgAddIntermediateResults( auto* value = codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - codegen.emitDecimalAddWithOverflow( - group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + emitDecimalAddWithOverflow( + codegen, group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); // count += incoming count. 
auto* oldCount = codegen.loadValue(group, b.getInt64Ty(), slot.offset + kCountOffset); @@ -127,13 +129,45 @@ bool canCompileDecimalAvgExtract(const HashAggrJitSlot&, bool partialOutput) { return true; } +void emitDecimalAvgExtract( + HashAggrJitCodegen& codegen, + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot, + bool partialOutput) { + auto& b = codegen.builder(); + const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalAvg" + : "jit_HashAggrExtractFinalDecimalAvg"; + auto* longDecimal = b.getInt8( + slot.desc.auxPrecision > bytedance::bolt::ShortDecimalType::kMaxPrecision + ? 1 + : 0); + b.CreateCall( + codegen.module().getFunction(fn), + {vector, + row, + group, + b.getInt32(slot.offset), + b.getInt32(slot.desc.precision), + b.getInt32(slot.desc.scale), + b.getInt32(slot.desc.auxPrecision), + b.getInt32(slot.desc.auxScale), + longDecimal}); +} + void compileDecimalAvgExtract( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { - codegen.emitDecimalAvgExtract( - target.output.vector(), target.row, group, slot, target.partialOutput); + emitDecimalAvgExtract( + codegen, + target.output.vector(), + target.row, + group, + slot, + target.partialOutput); } } // namespace diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index eeaf18575..f693ea56f 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -9,6 +9,7 @@ #include "bolt/jit/aggregation/HashAggrJit.h" #include "bolt/jit/aggregation/HashAggrJitDecimalState.h" +#include "bolt/jit/aggregation/ops/DecimalOps.h" namespace bytedance::bolt::jit { @@ -53,8 +54,12 @@ void compileDecimalSumAddRawInput( auto* value = codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - codegen.emitDecimalAddWithOverflow( - group, slot.offset + 
kSumOffset, slot.offset + kOverflowOffset, value); + emitDecimalAddWithOverflow( + codegen, + group, + slot.offset + kSumOffset, + slot.offset + kOverflowOffset, + value); codegen.storeValue( group, b.getInt8Ty(), slot.offset + kIsEmptyOffset, b.getInt8(0)); } @@ -97,8 +102,12 @@ void compileDecimalSumAddIntermediateResults( auto* value = codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); - codegen.emitDecimalAddWithOverflow( - group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); + emitDecimalAddWithOverflow( + codegen, + group, + slot.offset + kSumOffset, + slot.offset + kOverflowOffset, + value); // isEmpty = isEmpty && incomingIsEmpty. auto* oldIsEmpty = codegen.loadValue(group, b.getInt8Ty(), slot.offset + kIsEmptyOffset); @@ -119,13 +128,41 @@ bool canCompileDecimalSumExtract(const HashAggrJitSlot&, bool) { return true; } +void emitDecimalSumExtract( + HashAggrJitCodegen& codegen, + llvm::Value* vector, + llvm::Value* row, + llvm::Value* group, + const HashAggrJitSlot& slot, + bool partialOutput) { + auto& b = codegen.builder(); + const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalSum" + : "jit_HashAggrExtractFinalDecimalSum"; + auto* longDecimal = + b.getInt8(slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 
1 : 0); + b.CreateCall( + codegen.module().getFunction(fn), + {vector, + row, + group, + b.getInt32(slot.offset), + b.getInt32(slot.desc.precision), + b.getInt32(slot.desc.scale), + longDecimal}); +} + void compileDecimalSumExtract( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { - codegen.emitDecimalSumExtract( - target.output.vector(), target.row, group, slot, target.partialOutput); + emitDecimalSumExtract( + codegen, + target.output.vector(), + target.row, + group, + slot, + target.partialOutput); } } // namespace @@ -141,6 +178,41 @@ const HashAggrJitOps* getDecimalSumOps() { return &kOps; } +void emitDecimalAddWithOverflow( + HashAggrJitCodegen& codegen, + llvm::Value* group, + int32_t sumOffset, + int32_t overflowOffset, + llvm::Value* addend) { + auto& b = codegen.builder(); + auto* i128Ty = b.getInt128Ty(); + auto* i64Ty = b.getInt64Ty(); + auto* zero128 = llvm::ConstantInt::get(i128Ty, 0); + + auto* oldSum = codegen.loadValue(group, i128Ty, sumOffset); + auto* newSum = b.CreateAdd(oldSum, addend); + codegen.storeValue(group, i128Ty, sumOffset, newSum); + + // Mirror jitHashAggrAddWithOverflow: + // +1 if a>0 && b>0 && result<0 (positive overflow) + // -1 if a<0 && b<0 && result>=0 (negative overflow) + auto* aPos = b.CreateICmpSGT(oldSum, zero128); + auto* bPos = b.CreateICmpSGT(addend, zero128); + auto* rNeg = b.CreateICmpSLT(newSum, zero128); + auto* posOverflow = b.CreateAnd(b.CreateAnd(aPos, bPos), rNeg); + + auto* aNeg = b.CreateICmpSLT(oldSum, zero128); + auto* bNeg = b.CreateICmpSLT(addend, zero128); + auto* rNonNeg = b.CreateICmpSGE(newSum, zero128); + auto* negOverflow = b.CreateAnd(b.CreateAnd(aNeg, bNeg), rNonNeg); + + auto* carry = b.CreateSub( + b.CreateZExt(posOverflow, i64Ty), b.CreateZExt(negOverflow, i64Ty)); + auto* oldOverflow = codegen.loadValue(group, i64Ty, overflowOffset); + codegen.storeValue( + group, i64Ty, overflowOffset, b.CreateAdd(oldOverflow, carry)); +} 
+ } // namespace bytedance::bolt::jit #endif // ENABLE_BOLT_JIT From 5b44abc12739c613b570e9f090a13e822d4cfd33 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 14:58:10 +0800 Subject: [PATCH 64/98] reusing sumcount in both non-jit and jit way --- .../lib/aggregates/AverageAggregateBase.h | 7 +-- bolt/functions/lib/aggregates/SumCount.h | 45 +++++++++++++++++++ bolt/jit/aggregation/ops/AvgOps.cpp | 10 +++-- 3 files changed, 52 insertions(+), 10 deletions(-) create mode 100644 bolt/functions/lib/aggregates/SumCount.h diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index f8869255f..aff93ec4a 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -33,6 +33,7 @@ #include "bolt/exec/Aggregate.h" #include "bolt/functions/lib/aggregates/AggregateToIntermediate.h" #include "bolt/functions/lib/aggregates/DecimalAggregate.h" +#include "bolt/functions/lib/aggregates/SumCount.h" #include "bolt/type/DecimalUtil.h" #include "bolt/vector/ComplexVector.h" #include "bolt/vector/DecodedVector.h" @@ -77,12 +78,6 @@ const SelectivityVector* getBaseRows( return baseRows; } -template -struct SumCount { - TSum sum{0}; - int64_t count{0}; -}; - } // namespace /// Partial aggregation produces a pair of sum and count. diff --git a/bolt/functions/lib/aggregates/SumCount.h b/bolt/functions/lib/aggregates/SumCount.h new file mode 100644 index 000000000..e7d41c521 --- /dev/null +++ b/bolt/functions/lib/aggregates/SumCount.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + */ + +#pragma once + +#include + +namespace bytedance::bolt::functions::aggregate { + +// Intermediate accumulator layout for AVG: a running sum and a count of +// non-null inputs. This is the single source of truth for the AVG intermediate +// memory layout. The hash-aggregation JIT codegen derives its field offsets +// from this struct (via offsetof) instead of mirroring the layout, so any +// change here is automatically picked up by the JIT path. +// +// Keep this header dependency-free (header-only template, no .cpp / no external +// symbols) so it can be included by both the non-JIT aggregates and the JIT +// module (bolt_thrustjit) without introducing inter-library link dependencies. 
+template +struct SumCount { + TSum sum{0}; + int64_t count{0}; +}; + +} // namespace bytedance::bolt::functions::aggregate diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index fedd904dc..23f978b56 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -9,14 +9,16 @@ #include +#include "bolt/functions/lib/aggregates/SumCount.h" + namespace bytedance::bolt::jit { namespace { -struct AvgAccumulatorLayout { - double sum; - int64_t count; -}; +// Single source of truth for the AVG intermediate layout: derive the JIT field +// offsets from the non-JIT SumCount struct so a change to SumCount is picked up +// here automatically instead of silently desyncing a mirrored copy. +using AvgAccumulatorLayout = functions::aggregate::SumCount; static_assert(std::is_standard_layout_v); From e643c3df7e83418342159b4113caf815682a125e Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 15:16:13 +0800 Subject: [PATCH 65/98] reusing accumulate structures in both non-jit and jit way --- .../lib/aggregates/DecimalAccumulatorLayout.h | 66 +++++++++++++++++++ .../lib/aggregates/DecimalAggregate.h | 9 ++- .../sparksql/aggregates/DecimalSumAggregate.h | 9 ++- .../jit/aggregation/HashAggrJitDecimalState.h | 26 +++----- 4 files changed, 84 insertions(+), 26 deletions(-) create mode 100644 bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h diff --git a/bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h b/bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h new file mode 100644 index 000000000..8dda85b1a --- /dev/null +++ b/bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * -------------------------------------------------------------------------- + * Copyright (c) ByteDance Ltd. and/or its affiliates. + * SPDX-License-Identifier: Apache-2.0 + * + * This modified file is released under the same license. + * -------------------------------------------------------------------------- + */ + +#pragma once + +#include + +#include "bolt/type/HugeInt.h" + +// Single source of truth for the in-memory layout of the decimal sum / decimal +// avg accumulators. The non-JIT accumulators (DecimalSum, +// LongDecimalWithOverflowState) derive from these POD layout bases, and the +// hash-aggregation JIT codegen aliases the same bases to compute field offsets +// (via offsetof) instead of mirroring the layout. A change here is therefore +// picked up by both paths automatically. +// +// Keep this header dependency-free (header-only PODs, no .cpp / no external +// symbols) so it can be included by both the non-JIT aggregates and the JIT +// module (bolt_thrustjit) without introducing inter-library link dependencies. +// +// IMPORTANT: classes deriving from these layouts must add *no* non-static data +// members (methods only), otherwise they stop being standard-layout and the +// offsets the JIT relies on become undefined. Each derived type guards this +// with static_assert(std::is_standard_layout_v<...>). +namespace bytedance::bolt::functions::aggregate { + +// Layout of the decimal sum accumulator: ROW(sum, isEmpty) with an overflow +// counter. Field order is part of the contract shared with the JIT path. 
+struct DecimalSumAccumulatorLayout { + int128_t sum{0}; + int64_t overflow{0}; + bool isEmpty{true}; +}; + +// Layout of the long-decimal-with-overflow accumulator (used by decimal avg): +// running sum, count of rows, and net overflow counter. NOTE: this is the +// in-memory layout {sum, count, overflow}, which differs from the serialized +// byte order {count, overflow, sum}; the JIT path reads memory, not serialized +// form, so it aligns with this layout. +struct LongDecimalWithOverflowLayout { + int128_t sum{0}; + int64_t count{0}; + int64_t overflow{0}; +}; + +} // namespace bytedance::bolt::functions::aggregate diff --git a/bolt/functions/lib/aggregates/DecimalAggregate.h b/bolt/functions/lib/aggregates/DecimalAggregate.h index 60466802d..6a0f9e496 100644 --- a/bolt/functions/lib/aggregates/DecimalAggregate.h +++ b/bolt/functions/lib/aggregates/DecimalAggregate.h @@ -32,6 +32,7 @@ #include "bolt/common/base/IOUtils.h" #include "bolt/exec/Aggregate.h" +#include "bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h" #include "bolt/type/HugeInt.h" #include "bolt/vector/FlatVector.h" namespace bytedance::bolt::functions::aggregate { @@ -42,7 +43,7 @@ namespace bytedance::bolt::functions::aggregate { * COUNT: Total number of rows so far. * OVERFLOW: Total count of net overflow or underflow so far. 
*/ -struct LongDecimalWithOverflowState { +struct LongDecimalWithOverflowState : LongDecimalWithOverflowLayout { public: void mergeWith(const StringView& serializedData) { BOLT_CHECK_EQ(serializedData.size(), serializedSize()); @@ -75,12 +76,10 @@ struct LongDecimalWithOverflowState { static constexpr size_t serializedSize() { return sizeof(int64_t) * 4; } - - int128_t sum{0}; - int64_t count{0}; - int64_t overflow{0}; }; +static_assert(std::is_standard_layout_v); + template class DecimalAggregate : public exec::Aggregate { public: diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index fffe34e9a..350db6bfe 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -31,14 +31,11 @@ #pragma once #include "bolt/exec/Aggregate.h" #include "bolt/expression/FunctionSignature.h" +#include "bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h" #include "bolt/vector/FlatVector.h" namespace bytedance::bolt::functions::aggregate::sparksql { -struct DecimalSum { - int128_t sum{0}; - int64_t overflow{0}; - bool isEmpty{true}; - +struct DecimalSum : DecimalSumAccumulatorLayout { void mergeWith(const DecimalSum& other) { this->overflow += other.overflow; this->overflow += @@ -47,6 +44,8 @@ struct DecimalSum { } }; +static_assert(std::is_standard_layout_v); + template class DecimalSumAggregate : public exec::Aggregate { public: diff --git a/bolt/jit/aggregation/HashAggrJitDecimalState.h b/bolt/jit/aggregation/HashAggrJitDecimalState.h index 85271bf1e..d59699c47 100644 --- a/bolt/jit/aggregation/HashAggrJitDecimalState.h +++ b/bolt/jit/aggregation/HashAggrJitDecimalState.h @@ -5,25 +5,19 @@ #include #include -#include "bolt/type/Type.h" +#include "bolt/functions/lib/aggregates/DecimalAccumulatorLayout.h" namespace bytedance::bolt::jit { -// JIT-internal decimal accumulator layouts shared only by decimal aggregate -// 
codegen and decimal extract runtime helpers. Keep them out of the framework -// planning/types header so non-decimal ops don't depend on aggregate-private -// row state details. -struct JitDecimalSumState { - int128_t sum{0}; - int64_t overflow{0}; - bool isEmpty{true}; -}; - -struct JitDecimalAvgState { - int128_t sum{0}; - int64_t count{0}; - int64_t overflow{0}; -}; +// JIT-internal decimal accumulator layouts. These alias the shared POD layout +// bases that the non-JIT accumulators (DecimalSum / LongDecimalWithOverflowState) +// also derive from, so the JIT and non-JIT in-memory layouts stay in sync by +// construction (no mirrored copy to drift). The codegen / extract runtime read +// fields via offsetof on these aliases. +using JitDecimalSumState = + bytedance::bolt::functions::aggregate::DecimalSumAccumulatorLayout; +using JitDecimalAvgState = + bytedance::bolt::functions::aggregate::LongDecimalWithOverflowLayout; static_assert(std::is_standard_layout_v); static_assert(std::is_standard_layout_v); From 0389f8b726aee94f644275ed8cbc68e3a9a0628c Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 15:20:28 +0800 Subject: [PATCH 66/98] fix code style --- bolt/jit/aggregation/HashAggrJitDecimalState.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJitDecimalState.h b/bolt/jit/aggregation/HashAggrJitDecimalState.h index d59699c47..562daf594 100644 --- a/bolt/jit/aggregation/HashAggrJitDecimalState.h +++ b/bolt/jit/aggregation/HashAggrJitDecimalState.h @@ -14,10 +14,8 @@ namespace bytedance::bolt::jit { // also derive from, so the JIT and non-JIT in-memory layouts stay in sync by // construction (no mirrored copy to drift). The codegen / extract runtime read // fields via offsetof on these aliases. 
-using JitDecimalSumState = - bytedance::bolt::functions::aggregate::DecimalSumAccumulatorLayout; -using JitDecimalAvgState = - bytedance::bolt::functions::aggregate::LongDecimalWithOverflowLayout; +using JitDecimalSumState = functions::aggregate::DecimalSumAccumulatorLayout; +using JitDecimalAvgState = functions::aggregate::LongDecimalWithOverflowLayout; static_assert(std::is_standard_layout_v); static_assert(std::is_standard_layout_v); From 3faff33dccb759f3dd54e70cbed46bab378b01dc Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 15:28:08 +0800 Subject: [PATCH 67/98] rename runHashAggrJitChunks --- bolt/exec/GroupingSet.cpp | 4 ++-- bolt/exec/GroupingSet.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 175d95432..e778c9a69 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -540,7 +540,7 @@ void GroupingSet::addInputForActiveRows( auto& newGroups = lookup_->newGroups; std::vector jitExecuted; #ifdef ENABLE_BOLT_JIT - runHashAggrJitChunks(groups, newGroups, input, mayPushdown, jitExecuted); + runHashAggrJitAddChunks(groups, newGroups, input, mayPushdown, jitExecuted); #endif for (auto i = 0; i < aggregates_.size(); ++i) { if (!jitExecuted.empty() && jitExecuted[i]) { @@ -1124,7 +1124,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { << hashAggrJitChunks_.size(); } -void GroupingSet::runHashAggrJitChunks( +void GroupingSet::runHashAggrJitAddChunks( char** groups, folly::Range newGroups, const RowVectorPtr& input, diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index 3296a89f7..bdd150b79 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -288,7 +288,7 @@ class GroupingSet { #ifdef ENABLE_BOLT_JIT void maybeCreateHashAggrJitPlan(); - void runHashAggrJitChunks( + void runHashAggrJitAddChunks( char** groups, folly::Range newGroups, const RowVectorPtr& input, From e295039c8fdc8834fea912aca79d94c308950d67 Mon Sep 
17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 15:56:24 +0800 Subject: [PATCH 68/98] minor refactor: split extract partial output and final output --- bolt/jit/aggregation/HashAggrJit.cpp | 8 ++-- bolt/jit/aggregation/HashAggrJit.h | 9 ++-- bolt/jit/aggregation/ops/AvgOps.cpp | 49 +++++++++++++--------- bolt/jit/aggregation/ops/CountOps.cpp | 22 +++++++++- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 21 ++++++++-- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 21 ++++++++-- bolt/jit/aggregation/ops/MinMaxOps.cpp | 22 +++++++++- bolt/jit/aggregation/ops/SumOps.cpp | 22 +++++++++- 8 files changed, 139 insertions(+), 35 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 1d7206f35..c977600f3 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1013,11 +1013,13 @@ bool genExtractIR( output = std::make_unique(codegen, outputRuntime); } - if (slot.desc.ops->extract == nullptr) { + auto* extractFn = partialOutput ? slot.desc.ops->extractAccumulators + : slot.desc.ops->extractResults; + if (extractFn == nullptr) { return false; } - slot.desc.ops->extract( - codegen, group, slot, HashAggrJitExtractTarget{*output, row, partialOutput}); + extractFn( + codegen, group, slot, HashAggrJitExtractTarget{*output, row}); } auto* next = builder.CreateAdd(row, builder.getInt32(1)); diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 529492543..46e95a72a 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -45,7 +45,12 @@ struct HashAggrJitOps { AddFn addRawInput; AddFn addIntermediateResults; CanExtractFn canExtract; - ExtractFn extract; + // Writes the intermediate (partial) accumulator state to the output, mirroring + // the non-JIT extractAccumulators path. 
+ ExtractFn extractAccumulators; + // Writes the final aggregate result to the output, mirroring the non-JIT + // extractValues/extractResults path. + ExtractFn extractResults; }; struct HashAggrJitExtractTarget { @@ -53,8 +58,6 @@ struct HashAggrJitExtractTarget { const OutputAdapterCodegen& output; // The target row index (runtime llvm::Value) to write the extracted result. llvm::Value* row; - // Whether to emit partial (intermediate) results instead of final ones. - bool partialOutput; }; class IRRow { diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index 23f978b56..3cf84772d 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -104,7 +104,33 @@ bool canCompileAvgExtract(const HashAggrJitSlot& slot, bool) { return slot.desc.accumulatorKind == HashAggrJitValueKind::Double; } -void compileAvgExtract( +// Intermediate output is row(sum:double, count:bigint). All-null group yields +// (0, 0) with a non-null top-level row (isNull = 0), matching the non-JIT +// extractAccumulators path. +void compileAvgExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + auto& builder = codegen.builder(); + auto* sum = codegen.loadValue(group, builder.getDoubleTy(), slot.offset); + auto* count = codegen.loadValue( + group, builder.getInt64Ty(), slot.offset + kAvgCountOffset); + target.output.writeField( + target.row, + 0, + HashAggrJitValueKind::Double, + IRRow::pack(builder, sum, builder.getFalse())); + target.output.writeField( + target.row, + 1, + HashAggrJitValueKind::Int64, + IRRow::pack(builder, count, builder.getFalse())); + target.output.writeNull(target.row, builder.getFalse()); +} + +// Final output is double avg. count == 0 means all inputs were null -> null. 
+void compileAvgExtractValues( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot, @@ -113,24 +139,6 @@ void compileAvgExtract( auto* sum = codegen.loadValue(group, builder.getDoubleTy(), slot.offset); auto* count = codegen.loadValue( group, builder.getInt64Ty(), slot.offset + kAvgCountOffset); - if (target.partialOutput) { - // Intermediate output is row(sum:double, count:bigint). All-null group - // yields (0, 0) with a non-null top-level row (isNull = 0), matching the - // non-JIT extractAccumulators path. - target.output.writeField( - target.row, - 0, - HashAggrJitValueKind::Double, - IRRow::pack(builder, sum, builder.getFalse())); - target.output.writeField( - target.row, - 1, - HashAggrJitValueKind::Int64, - IRRow::pack(builder, count, builder.getFalse())); - target.output.writeNull(target.row, builder.getFalse()); - return; - } - // Final output is double avg. count == 0 means all inputs were null -> null. auto* isNull = builder.CreateICmpEQ(count, builder.getInt64(0)); auto* countAsDouble = builder.CreateSIToFP(count, builder.getDoubleTy()); auto* avg = builder.CreateFDiv(sum, countAsDouble); @@ -149,7 +157,8 @@ const HashAggrJitOps* getAvgOps() { &compileAvgAddRawInput, &compileAvgAddIntermediateResults, &canCompileAvgExtract, - &compileAvgExtract}; + &compileAvgExtractAccumulators, + &compileAvgExtractValues}; return &kOps; } diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp index 330e1d27b..fc81907f3 100644 --- a/bolt/jit/aggregation/ops/CountOps.cpp +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -71,6 +71,9 @@ bool canCompileCountExtract(const HashAggrJitSlot&, bool) { return true; } +// Count's intermediate accumulator and final result share the same scalar +// representation, so partial/final extract emit identical IR. The two named +// entry points below both forward to this helper. 
void compileCountExtract( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -84,6 +87,22 @@ void compileCountExtract( IRRow::pack(codegen.builder(), value, codegen.builder().getFalse())); } +void compileCountExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileCountExtract(codegen, group, slot, target); +} + +void compileCountExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileCountExtract(codegen, group, slot, target); +} + } // namespace const HashAggrJitOps* getCountOps() { @@ -93,7 +112,8 @@ const HashAggrJitOps* getCountOps() { &compileCountAddRawInput, &compileCountAddIntermediateResults, &canCompileCountExtract, - &compileCountExtract}; + &compileCountExtractAccumulators, + &compileCountExtractValues}; return &kOps; } diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index cbaf64c40..520516d34 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -156,7 +156,7 @@ void emitDecimalAvgExtract( longDecimal}); } -void compileDecimalAvgExtract( +void compileDecimalAvgExtractAccumulators( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot, @@ -167,7 +167,21 @@ void compileDecimalAvgExtract( target.row, group, slot, - target.partialOutput); + /*partialOutput=*/true); +} + +void compileDecimalAvgExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + emitDecimalAvgExtract( + codegen, + target.output.vector(), + target.row, + group, + slot, + /*partialOutput=*/false); } } // namespace @@ -179,7 +193,8 @@ const HashAggrJitOps* getDecimalAvgOps() { &compileDecimalAvgAddRawInput, &compileDecimalAvgAddIntermediateResults, &canCompileDecimalAvgExtract, - 
&compileDecimalAvgExtract}; + &compileDecimalAvgExtractAccumulators, + &compileDecimalAvgExtractValues}; return &kOps; } diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index f693ea56f..05c49410d 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -151,7 +151,7 @@ void emitDecimalSumExtract( longDecimal}); } -void compileDecimalSumExtract( +void compileDecimalSumExtractAccumulators( HashAggrJitCodegen& codegen, llvm::Value* group, const HashAggrJitSlot& slot, @@ -162,7 +162,21 @@ void compileDecimalSumExtract( target.row, group, slot, - target.partialOutput); + /*partialOutput=*/true); +} + +void compileDecimalSumExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + emitDecimalSumExtract( + codegen, + target.output.vector(), + target.row, + group, + slot, + /*partialOutput=*/false); } } // namespace @@ -174,7 +188,8 @@ const HashAggrJitOps* getDecimalSumOps() { &compileDecimalSumAddRawInput, &compileDecimalSumAddIntermediateResults, &canCompileDecimalSumExtract, - &compileDecimalSumExtract}; + &compileDecimalSumExtractAccumulators, + &compileDecimalSumExtractValues}; return &kOps; } diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp index 9f2d645f6..aadafe31a 100644 --- a/bolt/jit/aggregation/ops/MinMaxOps.cpp +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -78,6 +78,9 @@ bool canCompileMinMaxExtract(const HashAggrJitSlot& slot, bool) { slot.desc.accumulatorKind != HashAggrJitValueKind::Bool; } +// Min/max's intermediate accumulator and final result share the same scalar +// representation, so partial/final extract emit identical IR. The two named +// entry points below both forward to this helper. 
void compileMinMaxExtract( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -92,6 +95,22 @@ void compileMinMaxExtract( IRRow::pack(codegen.builder(), value, isNull)); } +void compileMinMaxExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileMinMaxExtract(codegen, group, slot, target); +} + +void compileMinMaxExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileMinMaxExtract(codegen, group, slot, target); +} + } // namespace const HashAggrJitOps* getMinMaxOps() { @@ -101,7 +120,8 @@ const HashAggrJitOps* getMinMaxOps() { &compileMinMaxUpdate, &compileMinMaxUpdate, &canCompileMinMaxExtract, - &compileMinMaxExtract}; + &compileMinMaxExtractAccumulators, + &compileMinMaxExtractValues}; return &kOps; } diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp index a32717d4e..92d2dccf6 100644 --- a/bolt/jit/aggregation/ops/SumOps.cpp +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -54,6 +54,9 @@ bool canCompileSumExtract(const HashAggrJitSlot& slot, bool) { slot.desc.accumulatorKind == HashAggrJitValueKind::Double; } +// Sum's intermediate accumulator and final result share the same scalar +// representation, so partial/final extract emit identical IR. The two named +// entry points below both forward to this helper. 
void compileSumExtract( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -68,6 +71,22 @@ void compileSumExtract( IRRow::pack(codegen.builder(), value, isNull)); } +void compileSumExtractAccumulators( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileSumExtract(codegen, group, slot, target); +} + +void compileSumExtractValues( + HashAggrJitCodegen& codegen, + llvm::Value* group, + const HashAggrJitSlot& slot, + const HashAggrJitExtractTarget& target) { + compileSumExtract(codegen, group, slot, target); +} + } // namespace const HashAggrJitOps* getSumOps() { @@ -77,7 +96,8 @@ const HashAggrJitOps* getSumOps() { &compileSumAccumulate, &compileSumAccumulate, &canCompileSumExtract, - &compileSumExtract}; + &compileSumExtractAccumulators, + &compileSumExtractValues}; return &kOps; } From d1e436c43722fe3fa3e4df74a6c57584b0620eaa Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 16:58:14 +0800 Subject: [PATCH 69/98] refactor: remove canExtract & support bool/int128 for extract --- bolt/jit/aggregation/HashAggrJit.cpp | 96 +++++++++++----------- bolt/jit/aggregation/HashAggrJit.h | 2 - bolt/jit/aggregation/ops/AvgOps.cpp | 6 -- bolt/jit/aggregation/ops/CountOps.cpp | 6 -- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 8 -- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 5 -- bolt/jit/aggregation/ops/MinMaxOps.cpp | 8 -- bolt/jit/aggregation/ops/SumOps.cpp | 7 -- 8 files changed, 46 insertions(+), 92 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index c977600f3..c8aef8b37 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -175,16 +175,15 @@ bool isFloatKind(HashAggrJitValueKind kind) { bool supportsRawFlatOutput(HashAggrJitValueKind kind) { switch (kind) { + case HashAggrJitValueKind::Bool: case HashAggrJitValueKind::Int8: case HashAggrJitValueKind::Int16: case 
HashAggrJitValueKind::Int32: case HashAggrJitValueKind::Int64: + case HashAggrJitValueKind::Int128: case HashAggrJitValueKind::Float: case HashAggrJitValueKind::Double: return true; - case HashAggrJitValueKind::Bool: - case HashAggrJitValueKind::Int128: - return false; } return false; } @@ -385,6 +384,31 @@ void emitOutputNullBit( builder.CreateSelect(isNullBool, nullWord, notNullWord), wordAddr); } +// Writes a scalar value into a flat output values buffer at 'row'. Bool uses +// bit-packed storage (one bit per row in a uint64 word array), written with the +// same word/mask/select pattern as the null bitmap; all other kinds are +// fixed-width and written with a direct store. +void emitFlatScalarValue( + llvm::IRBuilder<>& builder, + llvm::Value* values, + llvm::Value* row, + HashAggrJitValueKind kind, + llvm::Value* value) { + if (kind == HashAggrJitValueKind::Bool) { + auto* bit = value->getType()->isIntegerTy(1) + ? value + : builder.CreateICmpNE(value, builder.getInt8(0)); + emitOutputNullBit(builder, values, row, bit); + return; + } + auto* type = llvmType(builder, kind); + auto* typedValues = builder.CreatePointerCast(values, type->getPointerTo()); + auto* valueAddr = builder.CreateInBoundsGEP( + type, typedValues, builder.CreateZExt(row, builder.getInt64Ty())); + auto* store = builder.CreateStore(value, valueAddr); + store->setAlignment(llvm::Align(1)); +} + llvm::LoadInst* loadValue( llvm::IRBuilder<>& builder, llvm::Value* row, @@ -696,26 +720,16 @@ void ScalarOutputAdapterCodegen::write( llvm::Value* row, HashAggrJitValueKind kind, llvm::Value* irRow) const { - auto* value = IRRow::getValue(codegen_.builder(), irRow); - auto* isNull = IRRow::getIsNull(codegen_.builder(), irRow); - if (supportsRawFlatOutput(kind)) { - auto* type = codegen_.llvmType(kind); - auto* values = loadScalarOutputValues( - codegen_.builder(), output_); - auto* typedValues = codegen_.builder().CreatePointerCast( - values, type->getPointerTo()); - auto* valueAddr = 
codegen_.builder().CreateInBoundsGEP( - type, - typedValues, - codegen_.builder().CreateZExt(row, codegen_.builder().getInt64Ty())); - auto* store = codegen_.builder().CreateStore(value, valueAddr); - store->setAlignment(llvm::Align(1)); - auto* nulls = loadScalarOutputNulls( - codegen_.builder(), output_); - emitOutputNullBit( - codegen_.builder(), nulls, row, isNull); - return; - } + auto& builder = codegen_.builder(); + BOLT_CHECK( + supportsRawFlatOutput(kind), + "Unsupported raw flat scalar output kind for HashAggrJit"); + auto* value = IRRow::getValue(builder, irRow); + auto* isNull = IRRow::getIsNull(builder, irRow); + auto* values = loadScalarOutputValues(builder, output_); + emitFlatScalarValue(builder, values, row, kind, value); + auto* nulls = loadScalarOutputNulls(builder, output_); + emitOutputNullBit(builder, nulls, row, isNull); } void ScalarOutputAdapterCodegen::writeField( @@ -767,25 +781,17 @@ void RowOutputAdapterCodegen::writeField( int32_t field, HashAggrJitValueKind kind, llvm::Value* irRow) const { + auto& builder = codegen_.builder(); BOLT_CHECK( supportsRawFlatOutput(kind), "Unsupported raw ROW output field kind for HashAggrJit"); auto* child = loadChild(field); - auto* value = IRRow::getValue(codegen_.builder(), irRow); - auto* isNull = IRRow::getIsNull(codegen_.builder(), irRow); - auto* type = codegen_.llvmType(kind); - auto* values = loadScalarOutputValues( - codegen_.builder(), child); - auto* typedValues = - codegen_.builder().CreatePointerCast(values, type->getPointerTo()); - auto* row64 = codegen_.builder().CreateZExt(row, codegen_.builder().getInt64Ty()); - auto* valueAddr = codegen_.builder().CreateInBoundsGEP(type, typedValues, row64); - auto* store = codegen_.builder().CreateStore(value, valueAddr); - store->setAlignment(llvm::Align(1)); - auto* nulls = loadScalarOutputNulls( - codegen_.builder(), child); - emitOutputNullBit( - codegen_.builder(), nulls, row, isNull); + auto* value = IRRow::getValue(builder, irRow); + auto* 
isNull = IRRow::getIsNull(builder, irRow); + auto* values = loadScalarOutputValues(builder, child); + emitFlatScalarValue(builder, values, row, kind, value); + auto* nulls = loadScalarOutputNulls(builder, child); + emitOutputNullBit(builder, nulls, row, isNull); } void RowOutputAdapterCodegen::writeNull( @@ -1000,8 +1006,7 @@ bool genExtractIR( for (auto i = 0; i < slots.size(); ++i) { const auto& slot = slots[i]; - if (slot.desc.ops == nullptr || slot.desc.ops->canExtract == nullptr || - !slot.desc.ops->canExtract(slot, partialOutput)) { + if (slot.desc.ops == nullptr) { continue; } auto* outputAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); @@ -1136,16 +1141,7 @@ std::string HashAggrJitDescriptor::signature() const { */ bool HashAggrJitChunk::canExtract() const { - if (extract_ == nullptr) { - return false; - } - for (const auto& slot : slots_) { - if (slot.desc.ops == nullptr || slot.desc.ops->canExtract == nullptr || - !slot.desc.ops->canExtract(slot, partialOutput_)) { - return false; - } - } - return true; + return extract_ != nullptr; } bool HashAggrJitChunk::codegen() { diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 46e95a72a..1361e0c09 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -33,7 +33,6 @@ struct HashAggrJitOps { llvm::Value* row, const HashAggrJitSlot&, llvm::BasicBlock* nextBlock); - using CanExtractFn = bool (*)(const HashAggrJitSlot&, bool partialOutput); using ExtractFn = void (*)( HashAggrJitCodegen&, llvm::Value* group, @@ -44,7 +43,6 @@ struct HashAggrJitOps { CreateFn initGroup; AddFn addRawInput; AddFn addIntermediateResults; - CanExtractFn canExtract; // Writes the intermediate (partial) accumulator state to the output, mirroring // the non-JIT extractAccumulators path. 
ExtractFn extractAccumulators; diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index 3cf84772d..e8d061d1a 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -99,11 +99,6 @@ void compileAvgAddIntermediateResults( codegen.builder().CreateAdd(oldCount, count)); } -bool canCompileAvgExtract(const HashAggrJitSlot& slot, bool) { - // Only double avg accumulator layout is supported. - return slot.desc.accumulatorKind == HashAggrJitValueKind::Double; -} - // Intermediate output is row(sum:double, count:bigint). All-null group yields // (0, 0) with a non-null top-level row (isNull = 0), matching the non-JIT // extractAccumulators path. @@ -156,7 +151,6 @@ const HashAggrJitOps* getAvgOps() { &compileAvgInitGroup, &compileAvgAddRawInput, &compileAvgAddIntermediateResults, - &canCompileAvgExtract, &compileAvgExtractAccumulators, &compileAvgExtractValues}; return &kOps; diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp index fc81907f3..fe6b05679 100644 --- a/bolt/jit/aggregation/ops/CountOps.cpp +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -66,11 +66,6 @@ void compileCountAddIntermediateResults( addInc(codegen, group, slot, inc); } -bool canCompileCountExtract(const HashAggrJitSlot&, bool) { - // count result is always BIGINT and never null. - return true; -} - // Count's intermediate accumulator and final result share the same scalar // representation, so partial/final extract emit identical IR. The two named // entry points below both forward to this helper. 
@@ -111,7 +106,6 @@ const HashAggrJitOps* getCountOps() { &compileCountInitGroup, &compileCountAddRawInput, &compileCountAddIntermediateResults, - &canCompileCountExtract, &compileCountExtractAccumulators, &compileCountExtractValues}; return &kOps; diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index 520516d34..8cb44cb68 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -122,13 +122,6 @@ void compileDecimalAvgAddIntermediateResults( b.SetInsertPoint(continueBlock); } -bool canCompileDecimalAvgExtract(const HashAggrJitSlot&, bool partialOutput) { - // Both partial (extractAccumulators) and final extract go through runtime - // helpers. Final decimal avg keeps the divide/rescale logic in the helper to - // avoid duplicating Spark decimal semantics in LLVM IR. - return true; -} - void emitDecimalAvgExtract( HashAggrJitCodegen& codegen, llvm::Value* vector, @@ -192,7 +185,6 @@ const HashAggrJitOps* getDecimalAvgOps() { &compileDecimalAvgInitGroup, &compileDecimalAvgAddRawInput, &compileDecimalAvgAddIntermediateResults, - &canCompileDecimalAvgExtract, &compileDecimalAvgExtractAccumulators, &compileDecimalAvgExtractValues}; return &kOps; diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 05c49410d..aaac5f735 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -124,10 +124,6 @@ void compileDecimalSumAddIntermediateResults( b.SetInsertPoint(continueBlock); } -bool canCompileDecimalSumExtract(const HashAggrJitSlot&, bool) { - return true; -} - void emitDecimalSumExtract( HashAggrJitCodegen& codegen, llvm::Value* vector, @@ -187,7 +183,6 @@ const HashAggrJitOps* getDecimalSumOps() { &compileDecimalSumInitGroup, &compileDecimalSumAddRawInput, &compileDecimalSumAddIntermediateResults, - &canCompileDecimalSumExtract, &compileDecimalSumExtractAccumulators, 
&compileDecimalSumExtractValues}; return &kOps; diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp index aadafe31a..9dfa04ed7 100644 --- a/bolt/jit/aggregation/ops/MinMaxOps.cpp +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -71,13 +71,6 @@ void compileMinMaxUpdate( codegen.clearAccumulatorNull(group, slot); } -bool canCompileMinMaxExtract(const HashAggrJitSlot& slot, bool) { - // Flat setters exist for i8/i16/i32/i64/f32/f64 only. Int128 (long decimal) - // and Bool have no flat setter yet, fall back to non-JIT extract. - return slot.desc.accumulatorKind != HashAggrJitValueKind::Int128 && - slot.desc.accumulatorKind != HashAggrJitValueKind::Bool; -} - // Min/max's intermediate accumulator and final result share the same scalar // representation, so partial/final extract emit identical IR. The two named // entry points below both forward to this helper. @@ -119,7 +112,6 @@ const HashAggrJitOps* getMinMaxOps() { &compileMinMaxInitGroup, &compileMinMaxUpdate, &compileMinMaxUpdate, - &canCompileMinMaxExtract, &compileMinMaxExtractAccumulators, &compileMinMaxExtractValues}; return &kOps; diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp index 92d2dccf6..95cc236fa 100644 --- a/bolt/jit/aggregation/ops/SumOps.cpp +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -48,12 +48,6 @@ void compileSumAccumulate( codegen.storeValue(group, accType, slot.offset, newValue); } -bool canCompileSumExtract(const HashAggrJitSlot& slot, bool) { - // spark sum intermediate type == result type (bigint=bigint / double=double). - return slot.desc.accumulatorKind == HashAggrJitValueKind::Int64 || - slot.desc.accumulatorKind == HashAggrJitValueKind::Double; -} - // Sum's intermediate accumulator and final result share the same scalar // representation, so partial/final extract emit identical IR. The two named // entry points below both forward to this helper. 
@@ -95,7 +89,6 @@ const HashAggrJitOps* getSumOps() { &compileSumInitGroup, &compileSumAccumulate, &compileSumAccumulate, - &canCompileSumExtract, &compileSumExtractAccumulators, &compileSumExtractValues}; return &kOps; From daae495f104aa25383c5b2bb4613e20ce8ce06b8 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 17:02:29 +0800 Subject: [PATCH 70/98] add review document --- doc/hashaggr-jit-code-review.md | 133 ++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 doc/hashaggr-jit-code-review.md diff --git a/doc/hashaggr-jit-code-review.md b/doc/hashaggr-jit-code-review.md new file mode 100644 index 000000000..8688e6fdf --- /dev/null +++ b/doc/hashaggr-jit-code-review.md @@ -0,0 +1,133 @@ +# Hash Aggregation JIT 代码 Review 清单 + +> 对比范围:`d4e69030bfbe1d27eb31e6ad49027833bfce2c8e..HEAD`(hash aggr jit 支持代码,~7956 行 / 40 文件) +> 用途:逐条优化的工作清单。**本文档仅记录问题与建议,不含代码修改。** +> 关注维度:① 代码坏味道 ② JIT 与非 JIT 关键数据结构(raw input / intermediate / group / result)一致性 ③ 框架与聚合函数耦合残留 ④ 数据结构冗余 + +--- + +## 0. 总体结论 + +- **架构方向正确**:`HashAggrJitOps` 回调结构体已把各聚合语义下沉到 `ops/`,框架三大骨架 `genInitIR/genAddDenseIR/genExtractIR` 没有 `switch(kind)` 大分支。 +- **数据布局基准一致**:accumulator 起始 offset、null byte/mask 全部来自框架 `Aggregate::createHashAggrJitSlot`(`Aggregate.cpp:335`),JIT 不硬编码;内部字段用 `offsetof + static_assert(is_standard_layout)` 锁定。6 对聚合实现的字段顺序、ROW 子列顺序、null 标记、溢出语义**当前全部一致**。 +- **主要待优化点**: + 1. decimal 的 IR 生成(add-with-overflow / extract)泄漏在框架 `HashAggrJitCodegen` 中(耦合)。 + 2. 三大 gen 函数与 Sum/MinMax ops 存在大量样板重复(坏味道)。 + 3. JIT decimal state 与非 JIT accumulator 是「镜像复制」而非源码共享,缺跨编译单元交叉断言(一致性风险 + 冗余)。 + 4. descriptor 的 decimal 专属字段被所有 slot 冗余携带(数据冗余)。 + +优先级建议:先做 **C1(解耦 decimal)+ S1/S2(消骨架与 Sum/MinMax 重复)+ D1/D2(消 decimal 双定义与死字段)**,收益最大。 + +--- + +## 1. 
框架与聚合函数耦合残留(关注点③) + +| # | 位置 | 问题 | 建议 | 严重度 | +|---|---|---|---|---| +| **C1** | `HashAggrJit.cpp:797-841` `emitDecimalSumExtract`/`emitDecimalAvgExtract`;`843-874` `emitDecimalAddWithOverflow`;头文件 `HashAggrJit.h:266-291` | decimal 专属 IR 生成(i128 累加+overflow 进位、调 decimal runtime、读 precision/scale/auxPrecision/auxScale、long/short decimal 判断)是**框架类 `HashAggrJitCodegen` 的成员**;ops 只是转手调用。decimal 知识泄漏进框架。 | 下沉到 `DecimalSumOps.cpp`/`DecimalAvgOps.cpp` 内 `static` 辅助函数,只依赖通用原语(loadValue/storeValue/builder/module)。框架类不应有任何带 "Decimal" 的方法。 | 高 ✅ **已完成** | +| **C2** | `HashAggrJit.cpp:118-148` `ensureBuiltinDeclarations` | 框架构造函数无条件声明 4 个 decimal extract runtime 签名,即使 chunk 内无 decimal。 | 给 `HashAggrJitOps` 增加可选 `declareRuntime(llvm::Module&)` 回调,由各 decimal ops 自行声明;框架只声明通用 `jit_HashAggrResizeVector`。 | 中 | +| **C3** | `HashAggrJit.cpp:89-102` `kHashAggrRuntimeLinkAnchors` | 框架 TU 的链接锚点引用了 decimal 专属符号 `jit_HashAggrExtractFinalDecimalSum`,使框架编译期强依赖 decimal runtime。 | 框架锚点改用通用 runtime 符号;decimal runtime 锚点由 decimal ops TU 自持。 | 中 | +| **C4** | `HashAggrJit.cpp:1002`(配套 `GroupingSet.cpp:1177`) | 框架骨架 `if (checkInputNulls && !slot.desc.countStar)` 直接判 count 专属 flag。 | 用通用语义字段(如 `consumesInput`/`hasScalarInput`)替代 `countStar` 在框架层的判断;count 语义判定保留在 `CountOps`。 | 低 | +| **C5** | `HashAggrJit.cpp:1115-1129`(及被注释的 `signature()` `1197-1210`) | chunk 名拼接直接读 `countStar/mergeInput/decimal/kind` 等具体 flag。属较合理的元数据消费,但仍耦合 flag 名。 | 可提供 `ops->signatureSuffix(slot)` 回调由算子补充自身特征。 | 低 | + +> 编排层(`GroupingSet.cpp` / `Aggregate.cpp` 的 slot/descriptor 构建)已通过虚函数 `supportsHashAggrJit`/`createHashAggrJitDescriptor` 干净解耦,无按 kind 写死映射。 + +--- + +## 2. 
JIT 与非 JIT 数据结构一致性(关注点②) + +### 2.1 逐对结论(当前均一致) + +| 聚合 | JIT | 非 JIT | 结论 | +|---|---|---|---| +| SUM | `SumOps.cpp` 单标量写 `slot.offset` | `SumAggregateBase.h` 裸标量 | ✅ 一致 | +| AVG | `AvgOps.cpp:16-23` `{double sum; int64 count}` + offsetof | `SumCount` `AverageAggregateBase.h:81-84` | ✅ 一致(ROW 子列 `{sum,count}`) | +| DECIMAL SUM | `JitDecimalSumState{sum,overflow,isEmpty}` + offsetof | `DecimalSum` `DecimalSumAggregate.h:37-48` | ✅ 一致(ROW `{sum,isEmpty}`,溢出哨兵语义一致) | +| DECIMAL AVG | `JitDecimalAvgState{sum,count,overflow}` + offsetof | `LongDecimalWithOverflowState` `DecimalAggregate.h:45-82` | ✅ 一致(ROW `{sum,count}`) | +| COUNT | `CountOps.cpp` 单 i64,结果永不 null | `CountAggregate.cpp` 裸 int64 | ✅ 一致 | +| MIN/MAX | `MinMaxOps.cpp` 单标量,null 表「空」 | `MinMaxAggregates.cpp` 裸标量 | ✅ 一致(Int128/Bool 回退非 JIT) | + +### 2.2 一致性风险点(靠人工同步维持,需重点盯) + +| # | 位置 | 问题 | 风险 | +|---|---|---|---| +| **R1** ✅ **已完成** | `AvgOps.cpp:16-23` vs `AverageAggregateBase.h:81-84` | ~~`AvgAccumulatorLayout` 与 `SumCount` 是两份独立定义,跨编译单元无法交叉 `static_assert(sizeof/offsetof==)`。改一处忘改另一处会静默写错 count 偏移。~~ 已抽出零依赖头 `SumCount.h` 作为唯一权威定义;JIT 端 `using AvgAccumulatorLayout = functions::aggregate::SumCount`,offset 由权威结构 `offsetof` 派生,自动同步,镜像漂移消除。 | 中 | +| **R2** ✅ **已完成** | `HashAggrJitDecimalState.h:16-26` vs `DecimalSumAggregate.h:37-48` / `DecimalAggregate.h:45-82` | ~~同为镜像复制。注意 DecimalSum 是 `{sum,overflow,isEmpty}`、DecimalAvg 是 `{sum,count,overflow}`,**overflow 字段位置不同**;且 `LongDecimalWithOverflowState::serialize()` 顺序(count,overflow,sum)又与内存布局不同,极易混淆。JIT 只读内存不走 serialize,当前正确但无编译期交叉校验。~~ 已抽出零依赖头 `DecimalAccumulatorLayout.h`(`DecimalSumAccumulatorLayout`/`LongDecimalWithOverflowLayout` 两个 POD 布局基类);`DecimalSum`/`LongDecimalWithOverflowState` 继承之(只加方法、不加数据成员,保持 standard-layout),JIT 端 `using JitDecimalSumState/JitDecimalAvgState` 别名同一布局基类。布局自动同步,4 处 `static_assert(is_standard_layout_v)` 兜底防派生类加字段。 | 中 | +| **R3** | `HashAggrDecimalRuntime.cpp:29-110` | 4 个 runtime helper 逐行复制了非 JIT 的 
`computeFinalValue`/`computeAvg`/`adjustSumForOverflow`/`rescaleWithRoundUp` 及常量(`kCountPrecision=20` 等)。属行为复制,非 JIT 改 decimal 语义需同步两处。 | 中 | +| **R4** | `AvgOps.cpp:114-129` | 全 null group partial extract 输出 `(0,0)` 且 top-level 非 null,对齐的是 sparksql 重载版 `AverageAggregate.cpp:112-132`(非 lib 基类版)。需确认 JIT 仅用于 sparksql 路径。 | 低 | +| **R5** | `AvgOps.cpp:42-70`/`132` | avg 的最终 null 实际靠 `count==0` 判定,accumulator null byte 对 avg 不参与结果 null,存在但冗余,易误读。 | 低 | +| **R6** | `MinMaxOps.cpp:74-79`/`SumOps.cpp:51-55`/`AvgOps.cpp:100-103` | `canCompile*Extract` 用 accumulatorKind 白名单回退非 JIT,正确但靠人工维护;新增类型忘更新可能误走 JIT。 | 低 | + +--- + +## 3. 代码坏味道(关注点①) + +### 3.1 高严重度 + +| # | 位置 | 问题 | 建议 | +|---|---|---|---| +| **S1** | `HashAggrJit.cpp:906-925`/`953-978`/`1047-1067` | 三个 gen 函数的 LLVM 函数原型构造、entry/loop/end BB、`numRows<=0` guard、行循环 PHI、groupAddr/group 三段骨架几乎逐字重复 3 份。 | 抽取 `beginGroupLoop()/endGroupLoop()` 公共辅助返回 `{Function*, loop/end BB, PHI* row, group}`;`i8PtrTy/i32Ty/voidTy` 收进 `JitTypes` 缓存。 | +| **S2** | `SumOps.cpp:14-27`/`57-69` vs `MinMaxOps.cpp:14-25`/`81-93`(Avg init 前半段同) | Sum 与 MinMax 的 init(setNull+存0)与 extract(load+isNull+write)逐行相同。 | 抽 `compileZeroInitNullableAccumulator()` 与 `compileSimpleNullableExtract()` 复用。 | +| **S3** ✅ **已完成** | `HashAggrJitDecimalState.h:16-26` | ~~JIT decimal state 与非 JIT accumulator 重复定义,靠 `static_assert(standard_layout)` 无法保证与原结构字段顺序/对齐一致。~~ 已抽出零依赖布局基类(`DecimalAccumulatorLayout.h`),JIT 端用 `using` 别名复用,非 JIT 结构继承同一基类,布局单一权威来源。 | `using` 复用原结构,或加 `static_assert(sizeof/offsetof==)` 钉死并注释「布局必须与 X 同步」。 | + +### 3.2 中严重度 + +| # | 位置 | 问题 | 建议 | +|---|---|---|---| +| **S4** | `HashAggrJitTypes.h:141` + `HashAggrJit.cpp:1197-1210` | 被注释掉的死代码 `signature()`;``(`:15`) 仅服务这段死代码;逻辑与 chunk 名拼接重合。 | 删死代码或与 chunk 名拼接合并为真函数,移除多余 include。 | +| **S5** | `HashAggrJit.h:29-35` AddFn 的 `nextBlock` 参数 | 6/8 ops 实现未用(匿名 `BasicBlock*`),仅 decimal 用;框架调用后又无条件 `CreateBr(nextBlock)`(`:1028`),控制流职责模糊。 | 将「分支到 nextBlock」职责收归框架,decimal overflow 分流改用局部 if/PHI,从签名删除 
`nextBlock`。 | +| **S6** | `SumOps.cpp:51`/`AvgOps.cpp:100`/`MinMaxOps.cpp:74`/`CountOps.cpp:69`/`DecimalAvgOps.cpp:123` | canExtract 第二参数有的匿名有的命名 `partialOutput` 却不用,风格不一。 | 统一:不用就一律匿名,typedef 处注释语义。 | +| **S7** | `HashAggrJit.cpp:692-716` `ScalarOutputAdapterCodegen::write` | `kind` 不支持时静默 no-op(既不写也不报错),与 `RowOutputAdapterCodegen::writeField`(`:762` 用 BOLT_CHECK) 不一致。 | 补 `else BOLT_UNSUPPORTED(...)`。 | +| **S8** | `HashAggrJit.cpp:1115-1129` | 超长 ostringstream 拼接函数名,单字符 flag(`s/x`,`g/r`,`d/n`)无注释,可维护性低。 | 抽 `appendSlotSignature(out, slot)` 并加注释。 | +| **S9** | gen 用 `return false`(`:937-939,1019-1026,1092-1094`);适配器用 `BOLT_UNSUPPORTED`(`:583,590,610,723,759`);`writeField` 用 `BOLT_CHECK`(`:767`) | 同模块对「不支持/非法状态」三种处理方式混用。 | 明确契约:可降级→bool,编程错误/不变量破坏→BOLT_CHECK,头注释写清。 | + +### 3.3 低严重度 + +| # | 位置 | 问题 | 建议 | +|---|---|---|---| +| **S10** | `HashAggrJit.cpp:995-1001`/`1085-1091` | 循环内每 slot `make_unique` 适配器(轻量值类型无需堆分配+虚表)。 | 用 `std::variant<...>` 或栈对象+基类引用。 | +| **S11** | `HashAggrJit.cpp:988`/`1077` | `for (auto i=0; iid` 字符串 | 算子身份 enum + 字符串双重标识;`kind` 实际仅 MinMax 区分用到(`MinMaxOps.cpp:48,61`)+ chunk 命名,Sum/Avg/Count 的 kind 与 ops 冗余。 | 只留 `ops` 指针给 MinMax 加 `isMin` 标志或拆两个 ops;或去 id 字符串改 kind 派生名,二选一。 | +| **D5** | `HashAggrJitTypes.h:128`(`decimal`)、`:129-130`(`inputShape/outputShape`) | `decimal` 与「ops 是否 Decimal*」一一对应;shape 与适配器选择(`HashAggrJit.cpp:888-894`)一一对应,属派生型冗余。 | ops 表暴露 `isDecimal`/`defaultShape` 后可去字段;否则注释「与 ops 绑定,prepare 填充」。 | +| **D6** | `HashAggrJitTypes.h:56`/`68` union 两变体各存 `void* vector` | scalar 与 row 输出变体都放一个语义相同的顶层 `vector` 字段。 | 把 `vector` 提到 union 外公共头部。 | +| **D7** | `AvgOps.cpp:16-19`(局部)vs `HashAggrJitDecimalState.h`(共享头);Sum/MinMax 裸标量是隐式约定无 struct | accumulator layout 存放位置与表达方式不一致。 | 统一:都用具名 struct+offsetof 或都注释化。 | + +### 4.3 低严重度 + +| # | 位置 | 问题 | 建议 | +|---|---|---|---| +| **D8** | `HashAggrJitTypes.h:100-109`;`HashAggrJit.cpp:152-154`/`176-189` | `Bool` 在 llvmType 等价 Int8,仅少数处特判,枚举语义重叠。 | 评估改 Int8+`isBool` 标志,或注释说明差异点。 | +| 
**D9** | `HashAggrJitTypes.h:144-152` | slot 与 desc **未**重复携带 offset/null(已确认);但 `desc` 按值内嵌使多 slot 拷贝整份 descriptor(与 D1 叠加)。 | 若 descriptor 可共享,slot 改持 `const HashAggrJitDescriptor*` 减少拷贝与死字段复制。 | + +--- + +## 5. 建议的优化顺序 + +1. **C1** ✅ **已完成**:decimal IR 生成已下沉到 ops(`emitDecimalAddWithOverflow`/`emitDecimalSumExtract` 定义于 `DecimalSumOps.cpp`,`emitDecimalAvgExtract` 定义于 `DecimalAvgOps.cpp`,声明在新增的 `ops/DecimalOps.h`;框架类 `HashAggrJitCodegen` 不再持有任何 "Decimal" 方法)。 +2. **S1 + S2**:消除三大 gen 骨架重复、合并 Sum/MinMax init/extract。 +3. **S3/D2 + R2** ✅ **已完成**:decimal state 双定义已改为继承共享 POD 布局基类(`DecimalAccumulatorLayout.h`)+ JIT `using` 别名复用,布局单一权威来源(同时降一致性风险与冗余)。 +4. **D1/D9**:descriptor decimal 死字段拆出 + slot 改持指针。 +5. **C2/C3**:decimal runtime 声明与链接锚点下沉。 +6. **R1** ✅ **已完成** / **R3**:R1(Avg layout)已通过抽出 `SumCount.h` + JIT `using` 复用消除镜像;R3(decimal runtime 逻辑复制)待加交叉校验/同步注释。 +7. 其余坏味道(S4–S15)、冗余(D3–D8)按批次清理。 From 88dc2bb64aa0ce7f5c15c0dda6e89a7c61b1fa5a Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 18:37:48 +0800 Subject: [PATCH 71/98] fix short/long decimal inconsistency in decimal avg/sum --- .../sparksql/aggregates/DecimalSumAggregate.h | 15 +- bolt/jit/aggregation/HashAggrJit.cpp | 49 ++-- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 33 ++- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 32 ++- .../runtime/HashAggrDecimalRuntime.cpp | 228 +++++++++++++----- hashaggr_jit_refactor_plan.md | 57 ++++- 6 files changed, 306 insertions(+), 108 deletions(-) diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 350db6bfe..16429c78f 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -83,7 +83,12 @@ class DecimalSumAggregate : public exec::Aggregate { } const auto& valueType = context.isRawInput ? 
context.inputType : context.inputType->childAt(0); - const auto [resultPrecision, resultScale] = + // Unified decimal precision/scale convention across decimal aggregates: + // precision/scale -> intermediate (partial) decimal type + // auxPrecision/auxScale -> final result decimal type + // For decimal sum the intermediate and final decimal type are the same + // (both sumType_), so aux* mirror precision/scale. + const auto [sumPrecision, sumScale] = getDecimalPrecisionScale(*sumType_.get()); return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::Sum, @@ -97,10 +102,10 @@ class DecimalSumAggregate : public exec::Aggregate { .inputShape = context.isRawInput ? jit::HashAggrJitRuntimeShape::Scalar : jit::HashAggrJitRuntimeShape::Row, .outputShape = jit::HashAggrJitRuntimeShape::Scalar, - .precision = resultPrecision, - .scale = resultScale, - .auxPrecision = 0, - .auxScale = 0, + .precision = sumPrecision, + .scale = sumScale, + .auxPrecision = sumPrecision, + .auxScale = sumScale, .ops = jit::getDecimalSumOps()}; } #endif diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index c8aef8b37..1bde63ba2 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -87,19 +87,18 @@ constexpr uint64_t kRowOutputVectorOffset = // object (and thus every helper it defines) to be retained. This TU is always // pulled in by any JIT user (HashAggrJitChunk), so the anchor propagates. 
void jit_HashAggrResizeVector(char* vector, int32_t size); -void jit_HashAggrExtractFinalDecimalSum( +void jit_HashAggrExtractFinalShortDecimalSum( char* vector, int32_t row, char* group, int32_t offset, int32_t precision, - int32_t scale, - int8_t longDecimal); + int32_t scale); [[maybe_unused]] __attribute__((used)) const void* const kHashAggrRuntimeLinkAnchors[] = { reinterpret_cast(&jit_HashAggrResizeVector), - reinterpret_cast(&jit_HashAggrExtractFinalDecimalSum)}; + reinterpret_cast(&jit_HashAggrExtractFinalShortDecimalSum)}; } // extern "C" @@ -117,34 +116,56 @@ llvm::FunctionCallee declareFunction( void ensureBuiltinDeclarations(llvm::Module& module) { auto& context = module.getContext(); - auto* i8Ty = llvm::Type::getInt8Ty(context); auto* i32Ty = llvm::Type::getInt32Ty(context); auto* voidTy = llvm::Type::getVoidTy(context); auto* i8PtrTy = llvm::PointerType::get(context, 0); declareFunction(module, "jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); // Decimal extract helpers. - // Sum: (vector, row, group, offset, precision, scale, longDecimal). + // Sum: (vector, row, group, offset, precision, scale). 
declareFunction( module, - "jit_HashAggrExtractFinalDecimalSum", + "jit_HashAggrExtractFinalShortDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); declareFunction( module, - "jit_HashAggrExtractPartialDecimalSum", + "jit_HashAggrExtractFinalLongDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i8Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); declareFunction( module, - "jit_HashAggrExtractFinalDecimalAvg", + "jit_HashAggrExtractPartialShortDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty, i8Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); declareFunction( module, - "jit_HashAggrExtractPartialDecimalAvg", + "jit_HashAggrExtractPartialLongDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty, i8Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); + + // Avg: (vector, row, group, offset, precision, scale, resultPrecision, + // resultScale). 
+ declareFunction( + module, + "jit_HashAggrExtractFinalShortDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractFinalLongDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialShortDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + declareFunction( + module, + "jit_HashAggrExtractPartialLongDecimalAvg", + voidTy, + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); } llvm::Type* llvmType(llvm::IRBuilder<>& builder, HashAggrJitValueKind kind) { diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index 8cb44cb68..754423701 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -24,6 +24,12 @@ constexpr int32_t kCountOffset = constexpr int32_t kOverflowOffset = static_cast(offsetof(JitDecimalAvgState, overflow)); +HashAggrJitValueKind decimalKindForPrecision(int32_t precision) { + return precision > bytedance::bolt::ShortDecimalType::kMaxPrecision + ? 
HashAggrJitValueKind::Int128 + : HashAggrJitValueKind::Int64; +} + void compileDecimalAvgInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -88,7 +94,8 @@ void compileDecimalAvgAddIntermediateResults( continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "avg_decimal_merge", function, continueBlock); - auto* sumRow = input.readRowField(row, 0, slot.desc.inputKind); + const auto sumKind = decimalKindForPrecision(slot.desc.precision); + auto* sumRow = input.readRowField(row, 0, sumKind); auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); auto* sumIsNull = IRRow::getIsNull(b, sumRow); auto* countIsNull = IRRow::getIsNull(b, countRow); @@ -104,8 +111,7 @@ void compileDecimalAvgAddIntermediateResults( b.SetInsertPoint(mergeBlock); auto* sum = IRRow::getValue(b, sumRow); - auto* value = - codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); + auto* value = codegen.castValue(sum, sumKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); emitDecimalAddWithOverflow( codegen, group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); @@ -130,12 +136,18 @@ void emitDecimalAvgExtract( const HashAggrJitSlot& slot, bool partialOutput) { auto& b = codegen.builder(); - const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalAvg" - : "jit_HashAggrExtractFinalDecimalAvg"; - auto* longDecimal = b.getInt8( - slot.desc.auxPrecision > bytedance::bolt::ShortDecimalType::kMaxPrecision - ? 1 - : 0); + // long/short decimal of the written sum column: partial output writes the + // intermediate sum decimal (precision/scale); final output writes the result + // decimal (auxPrecision/auxScale). + const int32_t outPrecision = + partialOutput ? slot.desc.precision : slot.desc.auxPrecision; + const bool longDecimal = + decimalKindForPrecision(outPrecision) == HashAggrJitValueKind::Int128; + const char* fn = partialOutput + ? (longDecimal ? 
"jit_HashAggrExtractPartialLongDecimalAvg" + : "jit_HashAggrExtractPartialShortDecimalAvg") + : (longDecimal ? "jit_HashAggrExtractFinalLongDecimalAvg" + : "jit_HashAggrExtractFinalShortDecimalAvg"); b.CreateCall( codegen.module().getFunction(fn), {vector, @@ -145,8 +157,7 @@ void emitDecimalAvgExtract( b.getInt32(slot.desc.precision), b.getInt32(slot.desc.scale), b.getInt32(slot.desc.auxPrecision), - b.getInt32(slot.desc.auxScale), - longDecimal}); + b.getInt32(slot.desc.auxScale)}); } void compileDecimalAvgExtractAccumulators( diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index aaac5f735..f9a772a5c 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -10,6 +10,7 @@ #include "bolt/jit/aggregation/HashAggrJit.h" #include "bolt/jit/aggregation/HashAggrJitDecimalState.h" #include "bolt/jit/aggregation/ops/DecimalOps.h" +#include "bolt/type/Type.h" namespace bytedance::bolt::jit { @@ -23,6 +24,12 @@ constexpr int32_t kOverflowOffset = constexpr int32_t kIsEmptyOffset = static_cast(offsetof(JitDecimalSumState, isEmpty)); +HashAggrJitValueKind decimalKindForPrecision(int32_t precision) { + return precision > bytedance::bolt::ShortDecimalType::kMaxPrecision + ? 
HashAggrJitValueKind::Int128 + : HashAggrJitValueKind::Int64; +} + void compileDecimalSumInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -85,7 +92,8 @@ void compileDecimalSumAddIntermediateResults( continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "sum_decimal_merge", function, continueBlock); - auto* sumRow = input.readRowField(row, 0, slot.desc.inputKind); + const auto sumKind = decimalKindForPrecision(slot.desc.precision); + auto* sumRow = input.readRowField(row, 0, sumKind); auto* incomingIsEmpty = input.readRowFieldValue(row, 1, HashAggrJitValueKind::Bool); auto* sumIsNull = IRRow::getIsNull(b, sumRow); @@ -99,8 +107,7 @@ void compileDecimalSumAddIntermediateResults( b.SetInsertPoint(mergeBlock); auto* sum = IRRow::getValue(b, sumRow); - auto* value = - codegen.castValue(sum, slot.desc.inputKind, HashAggrJitValueKind::Int128); + auto* value = codegen.castValue(sum, sumKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); emitDecimalAddWithOverflow( codegen, @@ -132,10 +139,18 @@ void emitDecimalSumExtract( const HashAggrJitSlot& slot, bool partialOutput) { auto& b = codegen.builder(); - const char* fn = partialOutput ? "jit_HashAggrExtractPartialDecimalSum" - : "jit_HashAggrExtractFinalDecimalSum"; - auto* longDecimal = - b.getInt8(slot.desc.inputKind == HashAggrJitValueKind::Int128 ? 1 : 0); + // long/short decimal is decided by the actual output decimal type, not the + // input kind: precision/scale carry the intermediate (partial) decimal type, + // auxPrecision/auxScale carry the final result decimal type. + const int32_t outPrecision = + partialOutput ? slot.desc.precision : slot.desc.auxPrecision; + const bool longDecimal = + decimalKindForPrecision(outPrecision) == HashAggrJitValueKind::Int128; + const char* fn = partialOutput + ? (longDecimal ? "jit_HashAggrExtractPartialLongDecimalSum" + : "jit_HashAggrExtractPartialShortDecimalSum") + : (longDecimal ? 
"jit_HashAggrExtractFinalLongDecimalSum" + : "jit_HashAggrExtractFinalShortDecimalSum"); b.CreateCall( codegen.module().getFunction(fn), {vector, @@ -143,8 +158,7 @@ void emitDecimalSumExtract( group, b.getInt32(slot.offset), b.getInt32(slot.desc.precision), - b.getInt32(slot.desc.scale), - longDecimal}); + b.getInt32(slot.desc.scale)}); } void compileDecimalSumExtractAccumulators( diff --git a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp index 9d115d6f3..208bb9a89 100644 --- a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp +++ b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp @@ -109,108 +109,91 @@ std::optional jitDecimalAvgComputeFinal( return status.ok() ? std::optional(rescaledValue) : std::nullopt; } -} // namespace - -extern "C" { - -// Final decimal sum extract: write FlatVector. Null when the group is -// empty (all inputs null) or the sum overflows the result precision. -__attribute__((__visibility__("default"))) void -jit_HashAggrExtractFinalDecimalSum( +template +void jitHashAggrExtractFinalDecimalSum( char* vector, int32_t row, char* group, int32_t offset, - int32_t precision, - int32_t /*scale*/, - int8_t /*longDecimal*/) { + int32_t precision) { auto* state = reinterpret_cast(group + offset); auto* flat = reinterpret_cast(vector) - ->as>(); + ->asUnchecked>(); if (state->isEmpty) { flat->setNull(row, true); return; } + bool overflow = false; auto result = jitDecimalSumComputeFinal(state, precision, overflow); if (overflow) { flat->setNull(row, true); } else { - flat->set(row, result); + flat->set(row, static_cast(result)); } } -// Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). 
-__attribute__((__visibility__("default"))) void -jit_HashAggrExtractPartialDecimalSum( +template +void jitHashAggrExtractPartialDecimalSum( char* vector, int32_t row, char* group, int32_t offset, - int32_t precision, - int32_t /*scale*/, - int8_t /*longDecimal*/) { + int32_t precision) { auto* state = reinterpret_cast(group + offset); auto* rowVector = reinterpret_cast(vector) - ->as(); + ->asUnchecked(); auto* sumVector = - rowVector->childAt(0)->asFlatVector(); - auto* isEmptyVector = rowVector->childAt(1)->asFlatVector(); + rowVector->childAt(0)->asUnchecked>(); + auto* isEmptyVector = + rowVector->childAt(1)->asUnchecked>(); rowVector->setNull(row, false); if (state->isEmpty) { sumVector->set(row, 0); isEmptyVector->set(row, true); return; } + bool overflow = false; auto result = jitDecimalSumComputeFinal(state, precision, overflow); if (overflow) { sumVector->setNull(row, true); - isEmptyVector->set(row, false); } else { - sumVector->set(row, result); - isEmptyVector->set(row, state->isEmpty); + sumVector->set(row, static_cast(result)); } + isEmptyVector->set(row, overflow ? false : state->isEmpty); } -// Partial decimal avg extract: write row(sum:decimal, count:bigint). -// Overflow during sum adjustment -> sum child set to null, count kept. 
-__attribute__((__visibility__("default"))) void -jit_HashAggrExtractPartialDecimalAvg( +template +void jitHashAggrExtractPartialDecimalAvg( char* vector, int32_t row, char* group, - int32_t offset, - int32_t /*precision*/, - int32_t /*scale*/, - int32_t /*resultPrecision*/, - int32_t /*resultScale*/, - int8_t /*longDecimal*/) { + int32_t offset) { auto* state = reinterpret_cast(group + offset); auto* rowVector = reinterpret_cast(vector) - ->as(); + ->asUnchecked(); auto* sumVector = - rowVector->childAt(0)->asFlatVector(); - auto* countVector = rowVector->childAt(1)->asFlatVector(); + rowVector->childAt(0)->asUnchecked>(); + auto* countVector = + rowVector->childAt(1)->asUnchecked>(); rowVector->setNull(row, false); countVector->set(row, state->count); std::optional adjustedSum = bytedance::bolt::DecimalUtil::adjustSumForOverflow( state->sum, state->overflow); if (adjustedSum.has_value()) { - sumVector->set(row, adjustedSum.value()); + sumVector->set(row, static_cast(adjustedSum.value())); } else { sumVector->setNull(row, true); } } -// Final decimal avg extract: write FlatVector. Null when -// the group is empty (all inputs null) or any overflow/rescale step fails. 
-__attribute__((__visibility__("default"))) void -jit_HashAggrExtractFinalDecimalAvg( +template +void jitHashAggrExtractFinalDecimalAvg( char* vector, int32_t row, char* group, @@ -218,35 +201,18 @@ jit_HashAggrExtractFinalDecimalAvg( int32_t precision, int32_t scale, int32_t resultPrecision, - int32_t resultScale, - int8_t longDecimal) { - auto* state = - reinterpret_cast(group + offset); - if (longDecimal) { - auto* flat = reinterpret_cast(vector) - ->as>(); - if (state->count == 0) { - flat->setNull(row, true); - return; - } - auto result = jitDecimalAvgComputeFinal( - state, precision, scale, resultPrecision, resultScale); - if (result.has_value()) { - flat->set(row, result.value()); - } else { - flat->setNull(row, true); - } - return; - } - + int32_t resultScale) { + auto* state = reinterpret_cast( + group + offset); auto* flat = reinterpret_cast(vector) - ->as>(); + ->asUnchecked>(); if (state->count == 0) { flat->setNull(row, true); return; } - auto result = - jitDecimalAvgComputeFinal(state, precision, scale, resultPrecision, resultScale); + + auto result = jitDecimalAvgComputeFinal( + state, precision, scale, resultPrecision, resultScale); if (result.has_value()) { flat->set(row, result.value()); } else { @@ -254,6 +220,134 @@ jit_HashAggrExtractFinalDecimalAvg( } } +} // namespace + +extern "C" { + +// Final decimal sum extract. Null when the group is empty (all inputs null) or +// the sum overflows the result precision. 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalShortDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/) { + jitHashAggrExtractFinalDecimalSum( + vector, row, group, offset, precision); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalLongDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/) { + jitHashAggrExtractFinalDecimalSum( + vector, row, group, offset, precision); +} + +// Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialShortDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/) { + jitHashAggrExtractPartialDecimalSum( + vector, row, group, offset, precision); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialLongDecimalSum( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t /*scale*/) { + jitHashAggrExtractPartialDecimalSum( + vector, row, group, offset, precision); +} + +// Partial decimal avg extract: write row(sum:decimal, count:bigint). +// Overflow during sum adjustment -> sum child set to null, count kept. 
+__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialShortDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t /*precision*/, + int32_t /*scale*/, + int32_t /*resultPrecision*/, + int32_t /*resultScale*/) { + jitHashAggrExtractPartialDecimalAvg(vector, row, group, offset); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractPartialLongDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t /*precision*/, + int32_t /*scale*/, + int32_t /*resultPrecision*/, + int32_t /*resultScale*/) { + jitHashAggrExtractPartialDecimalAvg( + vector, row, group, offset); +} + +// Final decimal avg extract: write FlatVector. Null when +// the group is empty (all inputs null) or any overflow/rescale step fails. +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalShortDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t resultPrecision, + int32_t resultScale) { + jitHashAggrExtractFinalDecimalAvg( + vector, + row, + group, + offset, + precision, + scale, + resultPrecision, + resultScale); +} + +__attribute__((__visibility__("default"))) void +jit_HashAggrExtractFinalLongDecimalAvg( + char* vector, + int32_t row, + char* group, + int32_t offset, + int32_t precision, + int32_t scale, + int32_t resultPrecision, + int32_t resultScale) { + jitHashAggrExtractFinalDecimalAvg( + vector, + row, + group, + offset, + precision, + scale, + resultPrecision, + resultScale); +} + } // extern "C" #endif // ENABLE_BOLT_JIT diff --git a/hashaggr_jit_refactor_plan.md b/hashaggr_jit_refactor_plan.md index bff2d87b7..63aab908b 100644 --- a/hashaggr_jit_refactor_plan.md +++ b/hashaggr_jit_refactor_plan.md @@ -26,8 +26,6 @@ IRRow_t = llvm::StructType::get(value_type, i1_ty) `is_null` 永远在第二个字段,框架统一处理;`value_type` 内部结构对框架透明。 ---- - ## 1. 
当前问题(背景) 落地前必须理解这些已存在的痛点,重构必须**逐项消除**。 @@ -1516,3 +1514,58 @@ JIT 的 partial extract(`HashAggrDecimalRuntime.cpp` 的 `jit_HashAggrExtractP - `bolt/jit/aggregation/HashAggrJit.h`:`InputAdapterCodegen` 新增 `readRowFieldValue` 纯虚 + 两子类声明; - `bolt/jit/aggregation/HashAggrJit.cpp`:两子类实现; - `bolt/jit/aggregation/ops/AvgOps.cpp`、`ops/DecimalSumOps.cpp`:按上表切换调用。 + +--- + +## 13. Decimal short/long 专用 helper 修复与性能结论(2026-06-13) + +本轮围绕 decimal sum/avg 的 short/long decimal 判断与 runtime helper 做了两类收敛: + +1. **descriptor 语义统一**:`precision/scale` 表示 intermediate/partial decimal 类型,`auxPrecision/auxScale` 表示 final result decimal 类型;decimal sum 因 partial/final 类型相同,所以 `aux*` 镜像 `precision/scale`。 +2. **short/long decimal 在 codegen 期固定**:不再把 `longDecimal` 作为外部 C++ runtime helper 参数传入,避免 LLVM 无法跨外部函数边界消除无效分支;`emitDecimalSumExtract` / `emitDecimalAvgExtract` 直接按实际输出精度选择 short/long 专用 helper。 + +当前专用 helper 形态: + +```text +jit_HashAggrExtractFinalShortDecimalSum +jit_HashAggrExtractFinalLongDecimalSum +jit_HashAggrExtractPartialShortDecimalSum +jit_HashAggrExtractPartialLongDecimalSum +jit_HashAggrExtractFinalShortDecimalAvg +jit_HashAggrExtractFinalLongDecimalAvg +jit_HashAggrExtractPartialShortDecimalAvg +jit_HashAggrExtractPartialLongDecimalAvg +``` + +验证命令: + +```bash +cmake --build --preset conan-release --target bolt_hashaggr_jit_benchmark --parallel 16 +./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_regex='(width8|width16)' +``` + +与用户提供的 baseline 单次结果相比,关键结论如下: + +| case | baseline | 当前 | 变化 | +|---|---:|---:|---:| +| `width8_merge_decimal_sum_jit` | 21.68ms | 15.24ms | **-29.70%** | +| `width16_merge_decimal_sum_jit` | 42.44ms | 29.47ms | **-30.56%** | +| `width8_merge_decimal_avg_jit` | 13.67ms | 13.73ms | +0.44% | +| `width16_merge_decimal_avg_jit` | 26.90ms | 26.45ms | -1.67% | + +汇总: + +| 分组 | 几何平均变化 | +|---|---:| +| 所有 JIT 项 | **-4.23%** | +| decimal JIT 项 | **-16.67%** | +| 非 decimal JIT 项 | -1.98% | + +因此,当前结论是: + +- `decimal_sum_jit` 从 baseline 
的“慢于 nojit”变成“明显快于 nojit”: + - width8:15.24ms vs nojit 19.52ms,约 **21.9%** faster; + - width16:29.47ms vs nojit 39.11ms,约 **24.6%** faster。 +- `decimal_avg_jit` 基本持平,无系统性回退。 +- 非 decimal 项与本轮改动无直接关系,单次结果有正有负,整体未观察到系统性退化。 +- 后续 decimal extract/merge 相关重构应继续坚持:**能在 codegen 期确定的类型选择,不要作为 runtime 参数留给外部 helper 分支处理**。 From 671be25af4d31f3b373351d676f68913e79e78bc Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 18:38:10 +0800 Subject: [PATCH 72/98] add review report --- doc/hash_aggr_jit_state_consistency_review.md | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 doc/hash_aggr_jit_state_consistency_review.md diff --git a/doc/hash_aggr_jit_state_consistency_review.md b/doc/hash_aggr_jit_state_consistency_review.md new file mode 100644 index 000000000..8473bd5d0 --- /dev/null +++ b/doc/hash_aggr_jit_state_consistency_review.md @@ -0,0 +1,182 @@ +# Bolt HashAggrJit — JIT vs 非 JIT 中间状态一致性审计 + +> 审计分支:`origin/hash_aggr_jit`(原审计 HEAD: `2b6de6e186 remove useless code`) +> 比对范围:`bolt/jit/aggregation/`、`bolt/exec/GroupingSet.{cpp,h}`、`bolt/functions/{lib,sparksql,prestosql}/aggregates/` +> 审计目标:所有支持 JIT 的聚合算子,在所有支持的输入参数类型下,JIT 与非 JIT 的中间状态字节布局与运行时语义是否完全等价。 + +> **【复核更新 @ 当前工作区】** 本文档原始审计基于 `2b6de6e186`,下文已就地标注每条结论在当前工作区的状态: +> - **B3 已解决(更优方式)**:decimal/avg 的 JIT state 现为 `using` 别名直接指向非 JIT 继承的同一 POD 布局基类(`DecimalAccumulatorLayout.h` / `SumCount.h`),布局已是**同一类型**而非镜像,cross-assert 已多余。 +> - **B6 已过时**:Int128/Bool extract 已支持、`canCompileMinMaxExtract` 及整个 `CanExtractFn` 已删除。 +> - **B1/B2 已解决**:decimal sum/avg extract 已在 codegen 期按实际输出精度选择 short/long 专用 runtime helper,runtime 内不再保留 `longDecimal` 分支;`precision/scale` 与 `auxPrecision/auxScale` 已统一为 intermediate/final 语义。 +> - **B4 已解决(最小闭环)**:decimal sum/avg merge partial row 的 sum 字段读取 kind 已改为按 `precision` 推导,不再复用原始输入列 `inputKind`。 +> - **B5/B7 仍有效**:仅为防回归测试/断言加固类,P3。 + +--- + +## 结论一句话 +累加器**字节布局**已通过单一权威 POD 布局对齐;`decimal_sum` / `decimal_avg` 的短/长 decimal extract 与 partial merge 
位宽判断也已修复。当前剩余项主要是 B5/B7 这类防回归测试或断言加固,不再有已知 P0/P1 的 JIT/非 JIT 状态不一致问题。 + +--- + +## A. 字节级一致性表 + +| 算子 | JIT 结构 | 非 JIT 结构 | 布局 | 等价? | +|---|---|---|---|---| +| `avg` | `AvgAccumulatorLayout{double sum; i64 count}` `AvgOps.cpp:16` | `SumCount` `AverageAggregateBase.h:81` | 16B | ✅ **【已修】** JIT 端 `using AvgAccumulatorLayout = SumCount`,同一类型 | +| `count(*/col)` | 单 `i64` `CountOps.cpp:18` | `sizeof(i64)` `CountAggregate.cpp:49` | 8B | ✅ | +| `sum` (int/float) | `TAccumulator` `SumOps.cpp:14` | `TAccumulator` `SumAggregateBase.h:78` | 同 | ✅ | +| `min/max` (非 i128) | `T` `MinMaxOps.cpp:14` | `T` `MinMaxAggregates.cpp:93` | 同 | ✅ | +| `decimal sum` | `JitDecimalSumState{i128 sum; i64 overflow; bool isEmpty}` `HashAggrJitDecimalState.h:16` | `DecimalSum{i128 sum; i64 overflow; bool isEmpty}` `DecimalSumAggregate.h:37` | 32B | ✅ **【已修】** JIT `using` 别名指向 `DecimalSumAccumulatorLayout`,`DecimalSum` 继承之,单一权威布局 | +| `decimal avg` | `JitDecimalAvgState{i128 sum; i64 count; i64 overflow}` `HashAggrJitDecimalState.h:22` | `LongDecimalWithOverflowState` 字段同 `DecimalAggregate.h:45` | 32B | ✅ **【已修】** 同上,`using` → `LongDecimalWithOverflowLayout`,继承复用 | + +--- + +## B. 真实差异(按严重度) + +### ✅ B1. 
`decimal_sum` partial/final extract 硬编码 `int128` 输出(已解决) + +> **【复核 @ 当前工作区:已解决】** `emitDecimalSumExtract` 已在 codegen 期按实际输出 decimal 精度选择专用 helper:partial 看 `precision`,final 看 `auxPrecision`。短 decimal 调 `jit_HashAggrExtract*ShortDecimalSum` 写 `FlatVector`,长 decimal 调 `jit_HashAggrExtract*LongDecimalSum` 写 `FlatVector`;decimal sum 的 descriptor 中 `auxPrecision/auxScale` 与 `precision/scale` 镜像为同一个 sum type。 + +历史旧代码(已删除)曾在 `HashAggrDecimalRuntime.cpp` 中硬编码: +```cpp +vector->as>() // final: 直接吃 raw vector,不看 longDecimal 参数 +rowVector->childAt(0)->asFlatVector() // partial 同上 +``` + +- 调用方 `emitDecimalSumExtract` 不再传 `longDecimal`;short/long 选择已下沉为不同 runtime symbol。 +- `canCompileDecimalSumExtract` (`DecimalSumOps.cpp:118`) 无条件 `return true`,没有任何回退门。 +- Spark 注册签名 `r_precision = min(38, a_precision + 10)`(`SumAggregate.cpp:117`),factory 在 `sumType->isShortDecimal()` 时显式构造 `DecimalSumAggregate`(`SumAggregate.cpp:158`)。 +- 结果:`sum(DECIMAL(5,2))` / `sum(DECIMAL(8,3))` 的结果列就是 `FlatVector`,JIT 那里 `dynamic_cast>` 得到 `nullptr` → 空指针 set/setNull → **段错误 / 堆破坏**。 +- `GroupingSet.cpp:1304-1308` 的注释已显式承认 *"decimal avg's accumulatorKind is Int128 while its final result is a short decimal (FlatVector)"*——但那只保护 scalar output 自动按 vector 真实 type 推 kind 的路径;走 runtime helper 自己再 cast 一次的路径完全没保护到。 + +### ✅ B2. `decimal_avg` partial extract 同病(已解决) + +> **【复核 @ 当前工作区:已解决】** `emitDecimalAvgExtract` 已在 codegen 期选择 short/long 专用 runtime helper:partial 看中间 sum 精度 `precision`,final 看结果精度 `auxPrecision`。`jit_HashAggrExtractPartial*DecimalAvg` 与 final 路径一样,分别写短/长 decimal sum child。 + +历史问题:`HashAggrDecimalRuntime.cpp` 曾硬写 `asFlatVector`。Spark AVG 第二条签名 `ROW(DECIMAL(a_precision, a_scale), BIGINT)` 会沿用入参精度——短 decimal 时 partial 输出是 `int64` sum vector,旧实现会 crash。当前已通过 short/long 专用 helper 消除该风险,且 runtime 内无无效分支。 + +### 🟡 B3. 
JIT/非 JIT 结构没有跨层 `static_assert`(Major) + +> **【复核 @ `b4b99b5553`:已解决(更优方式),无需再加 static_assert】** 现已抽出零依赖 POD 布局基类(`DecimalAccumulatorLayout.h` 的 `DecimalSumAccumulatorLayout`/`LongDecimalWithOverflowLayout`、`SumCount.h` 的 `SumCount`):非 JIT 结构 `DecimalSum`/`LongDecimalWithOverflowState` **继承**之,JIT 端 `JitDecimalSumState`/`JitDecimalAvgState`/`AvgAccumulatorLayout` 用 `using` **别名同一基类**。两侧已是**同一类型**而非镜像副本,sizeof/offsetof 必然相等,cross-assert 已多余;各处保留 `static_assert(is_standard_layout_v)` 防止派生类误加数据成员破坏布局。 + +`HashAggrJitDecimalState.h:28-29` 只断言了 `is_standard_layout`,缺: +```cpp +static_assert(sizeof(JitDecimalSumState) == sizeof(sparksql::DecimalSum)); +static_assert(offsetof(JitDecimalSumState, sum) == offsetof(sparksql::DecimalSum, sum)); +// overflow / isEmpty 同理; +// JitDecimalAvgState vs LongDecimalWithOverflowState 同理; +// AvgAccumulatorLayout vs SumCount 同理。 +``` +ABI 完全靠手工同步,加 4 行最便宜也最实在。 + +### ✅ B4. row 输入 stride 仍按 plan 端 `slot.desc.inputKind`(已解决) + +> **【复核 @ 当前工作区:已解决】** 对 decimal sum/avg 的 partial merge,row field 0 是中间 sum decimal,真实位宽由中间精度 `precision` 决定,而不是原始输入列 `slot.desc.inputKind`。当前 `DecimalSumOps.cpp` / `DecimalAvgOps.cpp` 均通过 `decimalKindForPrecision(slot.desc.precision)` 读取 row field 并 cast 到 accumulator `Int128`。 + +历史问题:`DecimalSumOps.cpp` / `DecimalAvgOps.cpp` 读 row field 曾用 `slot.desc.inputKind`;runtime `fillHashAggrJitRowInputRuntime` 又按 vector 真实类型再反推一次。两侧不一致时 stride 错。当前在不扩 descriptor 的前提下,先利用已有 `precision` 作为 codegen 期 single source of truth,消除了 decimal partial row 的位宽漂移。 + +### 🟢 B5. 
`MinAggregate` 初值不同(语义等价,但容易看错) + +> **【复核 @ `b4b99b5553`:已验证 JIT == 非 JIT,Spark 下也一致;仅需补防回归测试】** +> 注意前提:Spark 与 Presto 的 min/max **共用同一份 `registerMinMax` + 同一个 `MinAggregate`/`MaxAggregate` + 同一份 JIT op**,所以确实需要验,但结论是一致的。 +> - 非 JIT 权威比较是 `SimpleVector::comparePrimitiveAsc`(`SimpleVector.h:368-380`):**NaN 视为最大**(NaN 排在所有非 NaN 之后),且该语义**不随 `SPARK_COMPATIBLE` 改变**——Spark/Presto 统一。 +> - JIT op(`MinMaxOps.cpp:45-59`)逐组合等价于 NaN=最大: +> - Min:`(oldIsNan && !valueIsNan) || (!valueIsNan && old>value)` → 避开 NaN,仅全 NaN 时结果为 NaN; +> - Max:`!oldIsNan && (valueIsNan || old - 对 {NaN,非NaN} 全部四种组合手工核对,结果与 `comparePrimitiveAsc` 完全一致。 +> - 初值 `0.0`(JIT)/ `NaN`(非 JIT Presto)都不参与比较:首条非 null 输入必定 `nullState=true` 无条件覆盖初值(`shouldStore = nullState || better`)。 +> - **结论**:B5 不是 bug,语义在 Spark 与 Presto 下均一致。剩余价值仅为**防回归**:补 `max(NaN,5.0)`/`min(NaN,5.0)` 的 JIT vs 非 JIT 对照用例钉死等价性,优先级 P3。 + +- 非 JIT Presto: `kInitialValue_ = NaN` (`MinMaxAggregates.cpp:367-371`) +- JIT: 统一写 `0.0` (`MinMaxOps.cpp:21`);靠 `shouldStore = nullState || better` 让第一条非 null 输入无条件覆盖 +- 我手算了所有 NaN/非 NaN 组合,Presto 下结果一致;**但 Spark MinMax 的 NaN 排序语义和 Presto 不同**,如果 Spark 也走同一份 JIT op,需要再验。 + +### 🟢 B6. `MinMax` 混合路径 + +> **【复核 @ `b4b99b5553`:已过时】** Int128/Bool 的 extract 已实现,`canCompileMinMaxExtract` 及整个 `CanExtractFn` 已删除,extract 不再走非 JIT 混合路径。B6 描述的现象不复存在。null 槽布局一致性的 NOTE 注释建议仍可保留参考。 + +`canCompileMinMaxExtract` (`MinMaxOps.cpp:74-79`) 对 `Int128`/`Bool` 返回 `false`,extract 走非 JIT,init/update 走 JIT。当前 `slot.nullByte/nullMask` 来自 `RowContainer::nullByte/Mask`(`GroupingSet.cpp:766`),和 `exec::Aggregate::isNull` 一致——OK,但建议加一条 NOTE 注释防止后续重构改 null 槽布局踩坑。 + +### 🟢 B7. 
整数 `sum` 溢出 + +> **【复核 @ `b4b99b5553`:Spark 语义下不是 bug,结论一致】** +> - 走 JIT 的整数 sum **只有 Spark**:Presto 的 sum 未注册 `supportsHashAggrJit`(prestosql 下仅 Count/MinMax 接入 JIT),因此不存在"Presto sum 复用 JIT"的实际路径。注意 sum 与 min/max 不同——min/max 是 Spark/Presto 共用注册,sum 各自独立注册(Spark 有自己的 `registerSumAggregate`)。 +> - 非 JIT Spark sum:`setSumAggOverflowCheckFlag(false)`(`SumAggregate.cpp:224`)→ `Overflow=true` 分支 → 静默回绕。 +> - JIT Spark sum:整数走 `CreateAdd`(`SumOps.cpp:46-47`)→ 静默回绕。 +> - **结论**:两者完全一致(都静默回绕),Spark 下无差异。同事建议的 `BOLT_CHECK(Overflow==true)` 仅为防止未来误改全局 flag 的护栏,属可选 P3。 + +- 非 JIT 默认 `CHECK_ADD` 抛异常;Spark 在 `registerSumAggregate` 显式 `setSumAggOverflowCheckFlag(false)` → `Overflow=true` → 静默回绕(`SumAggregateBase.h:190-197`、`SumAggregate.cpp:222`)。 +- JIT 永远 `CreateAdd` 静默回绕(`SumOps.cpp:46-47`)。 +- 当前只有 Spark 注册了 `supportsHashAggrJit`,**结论一致**。建议 JIT 入口加一条 `BOLT_CHECK(Overflow)` 防止后续误改全局 flag 让 Presto 路径也复用 JIT。 + +--- + +## C. 输入类型覆盖矩阵(按算子) + +| 算子 | 非 JIT 支持 | JIT 支持 | 不一致点 | +|---|---|---|---| +| `avg` | numeric + short/long decimal | numeric only(decimal 走另一条) | OK,gate 正确(`AverageAggregate.cpp:53-54` raw decimal 显式 false) | +| `count(*)/(col)` | 全部 | numeric + short/long decimal + hugeint | OK | +| `sum` | 同 avg | 非 decimal numeric + hugeint | OK | +| `min/max` | 全部 | numeric + short/long decimal + hugeint | OK;long decimal extract 走非 JIT (B6) | +| `decimal_sum` | short/long decimal raw + `ROW` intermediate | 同左 | ✅ B1/B4 已修 | +| `decimal_avg` | short/long decimal raw + `ROW` intermediate | 同左 | ✅ B2/B4 已修 | + +--- + +## D. 最终判定(Verdict) + +> **【复核 @ 当前工作区】** 下表反映原审计;当前状态:累加器布局一致性已升级为"单一权威类型"(B3 解决);B1/B2 的 Partial/Final extract 短 decimal 崩溃已修;B4 的 decimal partial merge row-field kind 漂移已修。剩余 B5/B7 均为 P3 加固项。 + +| 维度 | 一致? 
| +|---|---| +| Per-group 累加器结构体字节布局 | ✅ 全部一致 | +| 初值 + null bit 语义 | ✅ 一致(Presto MinMax 已校对;Spark MinMax NaN 排序待验) | +| Update 单点累加语义 | ✅ 一致(Spark 整数 sum 在 `Overflow=true` 下也一致) | +| Merge intermediate 语义 | ✅ B4 已修:decimal sum/avg partial row sum 字段按 `precision` 推导 kind | +| Partial extract → ROW 输出 | ✅ B1/B2 已修:decimal sum/avg 在 codegen 期选择 short/long 专用 helper | +| Final extract → 标量输出 | ✅ B1 已修:decimal sum final 在 codegen 期选择 short/long 专用 helper | +| 类型覆盖矩阵 | ✅ decimal 短/长结果不再依赖回退规避 crash | + +--- + +## E. 最小修复清单(按优先级) + +> **【复核 @ 当前工作区:当前优先级总览】** +> - **已解决**:B1 / B2 —— runtime helper 短 decimal 崩溃;B4 —— decimal partial row 输入 stride 漂移。 +> - **P3**:B5 / B7 —— 防御性测试/断言加固。 +> - **已解决/过时(无需再做)**:B3(已用单一权威布局根除)、B6(canExtract 已删、Int128/Bool extract 已支持)。 + +1. **B1/B2 修复(已完成)** + `emitDecimalSumExtract` / `emitDecimalAvgExtract` 已按 partial/final 的实际输出精度选择 short/long 专用 runtime helper,避免把 `longDecimal` 作为外部 C++ helper 参数导致 runtime 内保留无效分支。 + +2. **B3 修复(廉价护栏)** + 在 `HashAggrJitDecimalState.h` 同时 include 两侧头(`DecimalSumAggregate.h` / `DecimalAggregate.h` / `AverageAggregateBase.h`),加 `static_assert(sizeof, offsetof)` 跨层断言。AvgState 同理(与 `SumCount` cross-check)。 + +3. **B4 修复(已完成最小闭环)** + decimal partial merge 的 sum 字段已按 `precision` 推导为 `Int64/Int128`,并用该 kind 做 row-field read 和 cast。后续若要泛化到所有 ROW 字段,可再考虑 `HashAggrJitDescriptor.rowInputFields[i].kind`。 + +4. **B5 加固** + 补一份 Spark MinMax NaN-排序的对照测试(`max(NaN, 5.0)` / `min(NaN, 5.0)` JIT vs 非 JIT 必须完全一致),目前只校对了 Presto。 + +5. 
**B7 加固** + JIT 整数 sum slot 入口加 `BOLT_CHECK(Overflow == true)`,防止后续被静默改坏。 + +--- + +## 附录:关键文件一览 + +| 路径 | 作用 | +|---|---| +| `bolt/jit/aggregation/HashAggrJit.{h,cpp}` | JIT 主框架、IR codegen、runtime 装载 | +| `bolt/jit/aggregation/HashAggrJitTypes.h` | `HashAggrJitDescriptor` / `HashAggrJitSlot` / 输入输出 runtime 结构体 | +| `bolt/jit/aggregation/HashAggrJitDecimalState.h` | `JitDecimalSumState` / `JitDecimalAvgState`(**缺 cross-assert**) | +| `bolt/jit/aggregation/ops/{Avg,Count,Sum,MinMax,DecimalSum,DecimalAvg}Ops.cpp` | 各算子的 init/update/merge/extract 编译规则 | +| `bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp` | decimal sum/avg extract 的 C++ 运行时 helper(B1/B2 已修) | +| `bolt/exec/GroupingSet.cpp` | JIT chunk 调度、runtime fill、回退判断 | +| `bolt/functions/sparksql/aggregates/DecimalSumAggregate.h` | `DecimalSum` 非 JIT 结构 + `supportsHashAggrJit` | +| `bolt/functions/lib/aggregates/DecimalAggregate.h` | `LongDecimalWithOverflowState` 非 JIT 结构 | +| `bolt/functions/lib/aggregates/AverageAggregateBase.h` | `SumCount` 非 JIT 结构 | +| `bolt/functions/lib/aggregates/SumAggregateBase.h` | 整数 sum `CHECK_ADD` 与全局 `Overflow` flag | +| `bolt/functions/sparksql/aggregates/{SumAggregate,AverageAggregate}.cpp` | Spark sum/avg 注册 + JIT gate | +| `bolt/functions/prestosql/aggregates/{MinMaxAggregates,CountAggregate}.cpp` | Presto MinMax/Count 注册 + JIT gate | From 34df9b8a653c44262fa40f67c634f6463da66947 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sat, 13 Jun 2026 19:10:49 +0800 Subject: [PATCH 73/98] fix failed uts --- bolt/jit/aggregation/HashAggrJit.cpp | 15 +++++- doc/hash_aggr_jit_state_consistency_review.md | 46 +++++++++---------- 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 1bde63ba2..933af9cef 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -419,7 +419,18 @@ void emitFlatScalarValue( auto* bit = value->getType()->isIntegerTy(1) ? 
value : builder.CreateICmpNE(value, builder.getInt8(0)); - emitOutputNullBit(builder, values, row, bit); + auto* wordTy = builder.getInt64Ty(); + auto* wordIndex = builder.CreateLShr(row, builder.getInt32(6)); + auto* bitIndex = builder.CreateAnd(row, builder.getInt32(63)); + auto* words = builder.CreatePointerCast(values, wordTy->getPointerTo()); + auto* wordAddr = builder.CreateInBoundsGEP( + wordTy, words, builder.CreateZExt(wordIndex, builder.getInt64Ty())); + auto* word = builder.CreateLoad(wordTy, wordAddr); + auto* mask = builder.CreateShl( + builder.getInt64(1), builder.CreateZExt(bitIndex, builder.getInt64Ty())); + auto* trueWord = builder.CreateOr(word, mask); + auto* falseWord = builder.CreateAnd(word, builder.CreateNot(mask)); + builder.CreateStore(builder.CreateSelect(bit, trueWord, falseWord), wordAddr); return; } auto* type = llvmType(builder, kind); @@ -527,7 +538,7 @@ isInputNull(llvm::IRBuilder<>& builder, llvm::Value* nulls, llvm::Value* row) { builder.CreateInBoundsGEP( i64Ty, nullWords, builder.CreateZExt(wordIndex, builder.getInt64Ty()))); auto* shifted = builder.CreateLShr(word, builder.CreateZExt(bitIndex, i64Ty)); - return builder.CreateICmpNE( + return builder.CreateICmpEQ( builder.CreateAnd(shifted, builder.getInt64(1)), builder.getInt64(0)); } diff --git a/doc/hash_aggr_jit_state_consistency_review.md b/doc/hash_aggr_jit_state_consistency_review.md index 8473bd5d0..873d81a32 100644 --- a/doc/hash_aggr_jit_state_consistency_review.md +++ b/doc/hash_aggr_jit_state_consistency_review.md @@ -9,12 +9,12 @@ > - **B6 已过时**:Int128/Bool extract 已支持、`canCompileMinMaxExtract` 及整个 `CanExtractFn` 已删除。 > - **B1/B2 已解决**:decimal sum/avg extract 已在 codegen 期按实际输出精度选择 short/long 专用 runtime helper,runtime 内不再保留 `longDecimal` 分支;`precision/scale` 与 `auxPrecision/auxScale` 已统一为 intermediate/final 语义。 > - **B4 已解决(最小闭环)**:decimal sum/avg merge partial row 的 sum 字段读取 kind 已改为按 `precision` 推导,不再复用原始输入列 `inputKind`。 -> - **B5/B7 仍有效**:仅为防回归测试/断言加固类,P3。 +> - 
**B5/B7 已确认无问题**:B5 的 NaN 排序语义已确认 JIT/非 JIT 一致;B7 的 Spark 整数 sum 溢出语义也已确认一致,均无需继续处理。 --- ## 结论一句话 -累加器**字节布局**已通过单一权威 POD 布局对齐;`decimal_sum` / `decimal_avg` 的短/长 decimal extract 与 partial merge 位宽判断也已修复。当前剩余项主要是 B5/B7 这类防回归测试或断言加固,不再有已知 P0/P1 的 JIT/非 JIT 状态不一致问题。 +累加器**字节布局**已通过单一权威 POD 布局对齐;`decimal_sum` / `decimal_avg` 的短/长 decimal extract 与 partial merge 位宽判断也已修复;B5/B7 经复核确认不是问题。当前不再有已知 JIT/非 JIT 状态不一致待修项。 --- @@ -55,11 +55,11 @@ rowVector->childAt(0)->asFlatVector() // partial 同上 历史问题:`HashAggrDecimalRuntime.cpp` 曾硬写 `asFlatVector`。Spark AVG 第二条签名 `ROW(DECIMAL(a_precision, a_scale), BIGINT)` 会沿用入参精度——短 decimal 时 partial 输出是 `int64` sum vector,旧实现会 crash。当前已通过 short/long 专用 helper 消除该风险,且 runtime 内无无效分支。 -### 🟡 B3. JIT/非 JIT 结构没有跨层 `static_assert`(Major) +### ✅ B3. JIT/非 JIT 结构没有跨层 `static_assert`(已解决) > **【复核 @ `b4b99b5553`:已解决(更优方式),无需再加 static_assert】** 现已抽出零依赖 POD 布局基类(`DecimalAccumulatorLayout.h` 的 `DecimalSumAccumulatorLayout`/`LongDecimalWithOverflowLayout`、`SumCount.h` 的 `SumCount`):非 JIT 结构 `DecimalSum`/`LongDecimalWithOverflowState` **继承**之,JIT 端 `JitDecimalSumState`/`JitDecimalAvgState`/`AvgAccumulatorLayout` 用 `using` **别名同一基类**。两侧已是**同一类型**而非镜像副本,sizeof/offsetof 必然相等,cross-assert 已多余;各处保留 `static_assert(is_standard_layout_v)` 防止派生类误加数据成员破坏布局。 -`HashAggrJitDecimalState.h:28-29` 只断言了 `is_standard_layout`,缺: +历史建议曾要求在 `HashAggrJitDecimalState.h` 加跨层 `sizeof/offsetof` 断言: ```cpp static_assert(sizeof(JitDecimalSumState) == sizeof(sparksql::DecimalSum)); static_assert(offsetof(JitDecimalSumState, sum) == offsetof(sparksql::DecimalSum, sum)); @@ -67,7 +67,7 @@ static_assert(offsetof(JitDecimalSumState, sum) == offsetof(sparksql::DecimalSum // JitDecimalAvgState vs LongDecimalWithOverflowState 同理; // AvgAccumulatorLayout vs SumCount 同理。 ``` -ABI 完全靠手工同步,加 4 行最便宜也最实在。 +当前已通过共享 POD 布局基类让两侧复用同一类型,跨层 `sizeof/offsetof` 断言不再是必须项。 ### ✅ B4. 
row 输入 stride 仍按 plan 端 `slot.desc.inputKind`(已解决) @@ -75,9 +75,9 @@ ABI 完全靠手工同步,加 4 行最便宜也最实在。 历史问题:`DecimalSumOps.cpp` / `DecimalAvgOps.cpp` 读 row field 曾用 `slot.desc.inputKind`;runtime `fillHashAggrJitRowInputRuntime` 又按 vector 真实类型再反推一次。两侧不一致时 stride 错。当前在不扩 descriptor 的前提下,先利用已有 `precision` 作为 codegen 期 single source of truth,消除了 decimal partial row 的位宽漂移。 -### 🟢 B5. `MinAggregate` 初值不同(语义等价,但容易看错) +### ✅ B5. `MinAggregate` 初值不同(已确认无问题) -> **【复核 @ `b4b99b5553`:已验证 JIT == 非 JIT,Spark 下也一致;仅需补防回归测试】** +> **【最终确认:无问题,无需继续处理】** 已确认 JIT == 非 JIT,Spark/Presto 下语义一致;补防回归测试属于可选工程加固,不再作为待办项。 > 注意前提:Spark 与 Presto 的 min/max **共用同一份 `registerMinMax` + 同一个 `MinAggregate`/`MaxAggregate` + 同一份 JIT op**,所以确实需要验,但结论是一致的。 > - 非 JIT 权威比较是 `SimpleVector::comparePrimitiveAsc`(`SimpleVector.h:368-380`):**NaN 视为最大**(NaN 排在所有非 NaN 之后),且该语义**不随 `SPARK_COMPATIBLE` 改变**——Spark/Presto 统一。 > - JIT op(`MinMaxOps.cpp:45-59`)逐组合等价于 NaN=最大: @@ -85,11 +85,11 @@ ABI 完全靠手工同步,加 4 行最便宜也最实在。 > - Max:`!oldIsNan && (valueIsNan || old - 对 {NaN,非NaN} 全部四种组合手工核对,结果与 `comparePrimitiveAsc` 完全一致。 > - 初值 `0.0`(JIT)/ `NaN`(非 JIT Presto)都不参与比较:首条非 null 输入必定 `nullState=true` 无条件覆盖初值(`shouldStore = nullState || better`)。 -> - **结论**:B5 不是 bug,语义在 Spark 与 Presto 下均一致。剩余价值仅为**防回归**:补 `max(NaN,5.0)`/`min(NaN,5.0)` 的 JIT vs 非 JIT 对照用例钉死等价性,优先级 P3。 +> - **结论**:B5 不是 bug,语义在 Spark 与 Presto 下均一致。无需修复。 - 非 JIT Presto: `kInitialValue_ = NaN` (`MinMaxAggregates.cpp:367-371`) - JIT: 统一写 `0.0` (`MinMaxOps.cpp:21`);靠 `shouldStore = nullState || better` 让第一条非 null 输入无条件覆盖 -- 我手算了所有 NaN/非 NaN 组合,Presto 下结果一致;**但 Spark MinMax 的 NaN 排序语义和 Presto 不同**,如果 Spark 也走同一份 JIT op,需要再验。 +- 已确认所有 NaN/非 NaN 组合下 JIT 与非 JIT 一致;Spark/Presto 共用的比较语义与 JIT op 匹配。 ### 🟢 B6. 
`MinMax` 混合路径 @@ -97,17 +97,17 @@ ABI 完全靠手工同步,加 4 行最便宜也最实在。 `canCompileMinMaxExtract` (`MinMaxOps.cpp:74-79`) 对 `Int128`/`Bool` 返回 `false`,extract 走非 JIT,init/update 走 JIT。当前 `slot.nullByte/nullMask` 来自 `RowContainer::nullByte/Mask`(`GroupingSet.cpp:766`),和 `exec::Aggregate::isNull` 一致——OK,但建议加一条 NOTE 注释防止后续重构改 null 槽布局踩坑。 -### 🟢 B7. 整数 `sum` 溢出 +### ✅ B7. 整数 `sum` 溢出(已确认无问题) -> **【复核 @ `b4b99b5553`:Spark 语义下不是 bug,结论一致】** +> **【最终确认:无问题,无需继续处理】** Spark 语义下不是 bug,JIT/非 JIT 结论一致。 > - 走 JIT 的整数 sum **只有 Spark**:Presto 的 sum 未注册 `supportsHashAggrJit`(prestosql 下仅 Count/MinMax 接入 JIT),因此不存在"Presto sum 复用 JIT"的实际路径。注意 sum 与 min/max 不同——min/max 是 Spark/Presto 共用注册,sum 各自独立注册(Spark 有自己的 `registerSumAggregate`)。 > - 非 JIT Spark sum:`setSumAggOverflowCheckFlag(false)`(`SumAggregate.cpp:224`)→ `Overflow=true` 分支 → 静默回绕。 > - JIT Spark sum:整数走 `CreateAdd`(`SumOps.cpp:46-47`)→ 静默回绕。 -> - **结论**:两者完全一致(都静默回绕),Spark 下无差异。同事建议的 `BOLT_CHECK(Overflow==true)` 仅为防止未来误改全局 flag 的护栏,属可选 P3。 +> - **结论**:两者完全一致(都静默回绕),Spark 下无差异;不需要加额外修复。 - 非 JIT 默认 `CHECK_ADD` 抛异常;Spark 在 `registerSumAggregate` 显式 `setSumAggOverflowCheckFlag(false)` → `Overflow=true` → 静默回绕(`SumAggregateBase.h:190-197`、`SumAggregate.cpp:222`)。 - JIT 永远 `CreateAdd` 静默回绕(`SumOps.cpp:46-47`)。 -- 当前只有 Spark 注册了 `supportsHashAggrJit`,**结论一致**。建议 JIT 入口加一条 `BOLT_CHECK(Overflow)` 防止后续误改全局 flag 让 Presto 路径也复用 JIT。 +- 当前只有 Spark 注册了 `supportsHashAggrJit`,**结论一致**,无需继续处理。 --- @@ -126,12 +126,12 @@ ABI 完全靠手工同步,加 4 行最便宜也最实在。 ## D. 最终判定(Verdict) -> **【复核 @ 当前工作区】** 下表反映原审计;当前状态:累加器布局一致性已升级为"单一权威类型"(B3 解决);B1/B2 的 Partial/Final extract 短 decimal 崩溃已修;B4 的 decimal partial merge row-field kind 漂移已修。剩余 B5/B7 均为 P3 加固项。 +> **【复核 @ 当前工作区】** 下表反映原审计;当前状态:累加器布局一致性已升级为"单一权威类型"(B3 解决);B1/B2 的 Partial/Final extract 短 decimal 崩溃已修;B4 的 decimal partial merge row-field kind 漂移已修;B5/B7 已确认无问题。当前无已知一致性待修项。 | 维度 | 一致? 
| |---|---| | Per-group 累加器结构体字节布局 | ✅ 全部一致 | -| 初值 + null bit 语义 | ✅ 一致(Presto MinMax 已校对;Spark MinMax NaN 排序待验) | +| 初值 + null bit 语义 | ✅ 一致(MinMax NaN 排序已确认一致) | | Update 单点累加语义 | ✅ 一致(Spark 整数 sum 在 `Overflow=true` 下也一致) | | Merge intermediate 语义 | ✅ B4 已修:decimal sum/avg partial row sum 字段按 `precision` 推导 kind | | Partial extract → ROW 输出 | ✅ B1/B2 已修:decimal sum/avg 在 codegen 期选择 short/long 专用 helper | @@ -144,23 +144,23 @@ ABI 完全靠手工同步,加 4 行最便宜也最实在。 > **【复核 @ 当前工作区:当前优先级总览】** > - **已解决**:B1 / B2 —— runtime helper 短 decimal 崩溃;B4 —— decimal partial row 输入 stride 漂移。 -> - **P3**:B5 / B7 —— 防御性测试/断言加固。 +> - **已确认无问题**:B5 / B7 —— JIT/非 JIT 语义一致,无需继续处理。 > - **已解决/过时(无需再做)**:B3(已用单一权威布局根除)、B6(canExtract 已删、Int128/Bool extract 已支持)。 1. **B1/B2 修复(已完成)** `emitDecimalSumExtract` / `emitDecimalAvgExtract` 已按 partial/final 的实际输出精度选择 short/long 专用 runtime helper,避免把 `longDecimal` 作为外部 C++ helper 参数导致 runtime 内保留无效分支。 -2. **B3 修复(廉价护栏)** - 在 `HashAggrJitDecimalState.h` 同时 include 两侧头(`DecimalSumAggregate.h` / `DecimalAggregate.h` / `AverageAggregateBase.h`),加 `static_assert(sizeof, offsetof)` 跨层断言。AvgState 同理(与 `SumCount` cross-check)。 +2. **B3 修复(已完成)** + 已通过 `DecimalAccumulatorLayout.h` / `SumCount.h` 抽出共享 POD 布局基类,JIT 与非 JIT 复用同一权威布局类型,无需再补跨层 `sizeof/offsetof` 断言。 3. **B4 修复(已完成最小闭环)** decimal partial merge 的 sum 字段已按 `precision` 推导为 `Int64/Int128`,并用该 kind 做 row-field read 和 cast。后续若要泛化到所有 ROW 字段,可再考虑 `HashAggrJitDescriptor.rowInputFields[i].kind`。 -4. **B5 加固** - 补一份 Spark MinMax NaN-排序的对照测试(`max(NaN, 5.0)` / `min(NaN, 5.0)` JIT vs 非 JIT 必须完全一致),目前只校对了 Presto。 +4. **B5(已确认无问题)** + MinMax NaN 排序已确认 JIT/非 JIT 一致,无需修复。 -5. **B7 加固** - JIT 整数 sum slot 入口加 `BOLT_CHECK(Overflow == true)`,防止后续被静默改坏。 +5. 
**B7(已确认无问题)** + Spark 整数 sum 溢出语义已确认 JIT/非 JIT 一致,无需修复。 --- @@ -170,7 +170,7 @@ ABI 完全靠手工同步,加 4 行最便宜也最实在。 |---|---| | `bolt/jit/aggregation/HashAggrJit.{h,cpp}` | JIT 主框架、IR codegen、runtime 装载 | | `bolt/jit/aggregation/HashAggrJitTypes.h` | `HashAggrJitDescriptor` / `HashAggrJitSlot` / 输入输出 runtime 结构体 | -| `bolt/jit/aggregation/HashAggrJitDecimalState.h` | `JitDecimalSumState` / `JitDecimalAvgState`(**缺 cross-assert**) | +| `bolt/jit/aggregation/HashAggrJitDecimalState.h` | `JitDecimalSumState` / `JitDecimalAvgState`(已通过共享 POD 布局基类消除镜像漂移) | | `bolt/jit/aggregation/ops/{Avg,Count,Sum,MinMax,DecimalSum,DecimalAvg}Ops.cpp` | 各算子的 init/update/merge/extract 编译规则 | | `bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp` | decimal sum/avg extract 的 C++ 运行时 helper(B1/B2 已修) | | `bolt/exec/GroupingSet.cpp` | JIT chunk 调度、runtime fill、回退判断 | From 64ae0c337fcb61709b2bf1929317648db2edd835 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 14 Jun 2026 16:09:00 +0800 Subject: [PATCH 74/98] part1: refactor HashAggrJitDescriptor --- bolt/exec/GroupingSet.cpp | 26 ++++++----- .../prestosql/aggregates/CountAggregate.cpp | 25 ++++++----- .../prestosql/aggregates/MinMaxAggregates.cpp | 18 ++++---- .../sparksql/aggregates/AverageAggregate.cpp | 44 +++++++++---------- .../sparksql/aggregates/DecimalSumAggregate.h | 22 +++++----- .../sparksql/aggregates/SumAggregate.cpp | 14 +++--- bolt/jit/aggregation/HashAggrJit.cpp | 40 +++++++++++++---- bolt/jit/aggregation/HashAggrJit.h | 1 - bolt/jit/aggregation/HashAggrJitTypes.h | 20 ++++++--- bolt/jit/aggregation/ops/AvgOps.cpp | 1 - bolt/jit/aggregation/ops/CountOps.cpp | 3 +- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 7 ++- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 7 ++- bolt/jit/aggregation/ops/MinMaxOps.cpp | 1 - bolt/jit/aggregation/ops/SumOps.cpp | 1 - 15 files changed, 131 insertions(+), 99 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index e778c9a69..d822751a4 100644 --- 
a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -250,9 +250,10 @@ std::string hashAggrJitSlotDebugString( << " accKind=" << jit::hashAggrJitValueKindName(slot.desc.accumulatorKind) << " offset=" << slot.offset << " nullByte=" << slot.nullByte << " nullMask=" << static_cast(slot.nullMask) - << " countStar=" << slot.desc.countStar - << " mergeInput=" << slot.desc.mergeInput << " decimal=" << slot.desc.decimal - << " ops=" << (slot.desc.ops != nullptr ? slot.desc.ops->id : "null"); + << " countStar=" << slot.desc.isCountStar() + << " mergeInput=" << !slot.desc.isRawInput() + << " decimal=" << slot.desc.decimal + << " kindName=" << jit::hashAggrJitKindName(slot.desc.kind); return out.str(); } @@ -284,20 +285,21 @@ std::optional makeHashAggrJitSlot( return std::nullopt; } - const int32_t inputCount = aggregate.inputs.size(); - if (!(isRawInput && inputCount == 0) && inputCount != 1) { + std::vector inputTypes; + if (isRawInput) { + inputTypes = aggregate.rawInputTypes; + } else { + inputTypes = {aggregate.intermediateType}; + } + + if (!(isRawInput && inputTypes.empty()) && inputTypes.size() != 1) { return std::nullopt; } - const auto inputType = - inputCount == 0 ? nullptr - : (isRawInput ? 
aggregate.rawInputTypes[0] - : aggregate.intermediateType); const jit::HashAggrJitPlanContext context{ .isRawInput = isRawInput, .isPartialOutput = isPartialOutput, - .inputCount = inputCount, - .inputType = inputType}; + .inputTypes = std::move(inputTypes)}; if (!aggregate.function->supportsHashAggrJit(context)) { return std::nullopt; } @@ -1174,7 +1176,7 @@ void GroupingSet::runHashAggrJitAddChunks( skipReason = "selectivity vector is not dense activeRows or has no selections"; break; } - if (slot.desc.countStar) { + if (slot.desc.isCountStar()) { continue; } if (aggregate.inputs.size() != 1) { diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 34b533e8d..d859273cf 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -57,14 +57,19 @@ class CountAggregate : public SimpleNumericAggregate { bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { if (context.isRawInput) { - return context.inputCount == 0 || - (context.inputCount == 1 && context.inputType != nullptr && - !context.inputType->isRow() && - (context.inputType->isDecimal() || - jit::isHashAggrJitSupportedType(context.inputType->kind()))); + if (context.isCountStar()) { + return true; + } + if (context.inputTypes.size() != 1 || context.inputTypes[0] == nullptr) { + return false; + } + const auto& inputType = context.inputTypes[0]; + return !inputType->isRow() && + (inputType->isDecimal() || + jit::isHashAggrJitSupportedType(inputType->kind())); } - return context.inputCount == 1 && context.inputType != nullptr && - context.inputType->kind() == TypeKind::BIGINT; + return context.inputTypes.size() == 1 && context.inputTypes[0] != nullptr && + context.inputTypes[0]->kind() == TypeKind::BIGINT; } std::optional createHashAggrJitDescriptor( @@ -74,7 +79,8 @@ class CountAggregate : public SimpleNumericAggregate { } auto 
inputKind = jit::HashAggrJitValueKind::Int64; if (!context.isCountStar()) { - auto maybeInputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + auto maybeInputKind = + jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); if (!maybeInputKind.has_value()) { return std::nullopt; } @@ -84,8 +90,7 @@ class CountAggregate : public SimpleNumericAggregate { .kind = jit::HashAggrJitKind::Count, .inputKind = inputKind, .accumulatorKind = jit::HashAggrJitValueKind::Int64, - .countStar = context.isCountStar(), - .mergeInput = !context.isRawInput, + .context = context, .decimal = false, .inputShape = jit::HashAggrJitRuntimeShape::Scalar, .outputShape = jit::HashAggrJitRuntimeShape::Scalar, diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 364960d4d..1416bd573 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -55,11 +55,14 @@ class MinMaxAggregate : public SimpleNumericAggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - return context.inputCount == 1 && context.inputType != nullptr && - !context.inputType->isRow() && - (context.inputType->isDecimal() || - jit::isHashAggrJitSupportedType(context.inputType->kind()) || - context.inputType->kind() == TypeKind::HUGEINT); + if (context.inputTypes.size() != 1 || context.inputTypes[0] == nullptr) { + return false; + } + const auto& inputType = context.inputTypes[0]; + return !inputType->isRow() && + (inputType->isDecimal() || + jit::isHashAggrJitSupportedType(inputType->kind()) || + inputType->kind() == TypeKind::HUGEINT); } std::optional createHashAggrJitDescriptor( @@ -67,7 +70,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { if (!supportsHashAggrJit(context)) { return std::nullopt; } - auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + auto 
inputKind = jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); if (!inputKind.has_value()) { return std::nullopt; } @@ -75,8 +78,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { .kind = jitKind(), .inputKind = *inputKind, .accumulatorKind = *inputKind, - .countStar = false, - .mergeInput = !context.isRawInput, + .context = context, .decimal = false, .inputShape = jit::HashAggrJitRuntimeShape::Scalar, .outputShape = jit::HashAggrJitRuntimeShape::Scalar, diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 344e48315..7d2de2b92 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -46,19 +46,20 @@ class AverageAggregate #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputCount != 1 || !context.inputType) { + if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { return false; } + const auto& inputType = context.inputTypes[0]; if (context.isRawInput) { - if (context.inputType->isDecimal()) { + if (inputType->isDecimal()) { return false; } - return jit::isHashAggrJitSupportedType(context.inputType->kind()) || - context.inputType->kind() == TypeKind::HUGEINT; + return jit::isHashAggrJitSupportedType(inputType->kind()) || + inputType->kind() == TypeKind::HUGEINT; } - return context.inputType->isRow() && context.inputType->size() == 2 && - context.inputType->childAt(1)->kind() == TypeKind::BIGINT && - context.inputType->childAt(0)->kind() == TypeKind::DOUBLE; + return inputType->isRow() && inputType->size() == 2 && + inputType->childAt(1)->kind() == TypeKind::BIGINT && + inputType->childAt(0)->kind() == TypeKind::DOUBLE; } std::optional createHashAggrJitDescriptor( @@ -72,8 +73,7 @@ class AverageAggregate .kind = jit::HashAggrJitKind::Avg, .inputKind = jit::HashAggrJitValueKind::Double, .accumulatorKind = 
jit::HashAggrJitValueKind::Double, - .countStar = false, - .mergeInput = true, + .context = context, .decimal = false, .inputShape = jit::HashAggrJitRuntimeShape::Row, .outputShape = context.isPartialOutput @@ -86,7 +86,7 @@ class AverageAggregate .ops = jit::getAvgOps()}; } - auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); if (!inputKind.has_value()) { return std::nullopt; } @@ -94,8 +94,7 @@ class AverageAggregate .kind = jit::HashAggrJitKind::Avg, .inputKind = *inputKind, .accumulatorKind = jit::HashAggrJitValueKind::Double, - .countStar = false, - .mergeInput = false, + .context = context, .decimal = false, .inputShape = jit::HashAggrJitRuntimeShape::Scalar, .outputShape = context.isPartialOutput @@ -163,17 +162,16 @@ class DecimalAverageAggregate : public DecimalAggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputCount != 1 || !context.inputType) { + if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { return false; } + const auto& inputType = context.inputTypes[0]; if (context.isRawInput) { - return context.inputType->isDecimal() && - (context.inputType->isShortDecimal() || - context.inputType->isLongDecimal()); + return inputType->isDecimal(); } - return context.inputType->isRow() && context.inputType->size() == 2 && - context.inputType->childAt(0)->isDecimal() && - context.inputType->childAt(1)->kind() == TypeKind::BIGINT; + return inputType->isRow() && inputType->size() == 2 && + inputType->childAt(0)->isDecimal() && + inputType->childAt(1)->kind() == TypeKind::BIGINT; } std::optional createHashAggrJitDescriptor( @@ -181,21 +179,21 @@ class DecimalAverageAggregate : public DecimalAggregate { if (!supportsHashAggrJit(context)) { return std::nullopt; } + const auto& inputType = context.inputTypes[0]; const auto& valueType = - context.isRawInput ? 
context.inputType : context.inputType->childAt(0); + context.isRawInput ? inputType : inputType->childAt(0); const auto [sumPrecision, sumScale] = getDecimalPrecisionScale(*sumType_.get()); const auto [resultPrecision, resultScale] = context.isPartialOutput ? std::pair{0, 0} : getDecimalPrecisionScale(*this->resultType().get()); return jit::HashAggrJitDescriptor{ - .kind = jit::HashAggrJitKind::Avg, + .kind = jit::HashAggrJitKind::DecimalAvg, .inputKind = valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 : jit::HashAggrJitValueKind::Int128, .accumulatorKind = jit::HashAggrJitValueKind::Int128, - .countStar = false, - .mergeInput = !context.isRawInput, + .context = context, .decimal = true, .inputShape = context.isRawInput ? jit::HashAggrJitRuntimeShape::Scalar : jit::HashAggrJitRuntimeShape::Row, diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 16429c78f..d852acb72 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -63,17 +63,17 @@ class DecimalSumAggregate : public exec::Aggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputCount != 1 || !context.inputType) { + if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { return false; } + const auto& inputType = context.inputTypes[0]; if (context.isRawInput) { - return context.inputType->isDecimal() && - (context.inputType->isShortDecimal() || - context.inputType->isLongDecimal()); + return inputType->isDecimal() && + (inputType->isShortDecimal() || inputType->isLongDecimal()); } - return context.inputType->isRow() && context.inputType->size() == 2 && - context.inputType->childAt(0)->isDecimal() && - context.inputType->childAt(1)->kind() == TypeKind::BOOLEAN; + return inputType->isRow() && inputType->size() == 2 && + 
inputType->childAt(0)->isDecimal() && + inputType->childAt(1)->kind() == TypeKind::BOOLEAN; } std::optional createHashAggrJitDescriptor( @@ -81,8 +81,9 @@ class DecimalSumAggregate : public exec::Aggregate { if (!supportsHashAggrJit(context)) { return std::nullopt; } + const auto& inputType = context.inputTypes[0]; const auto& valueType = - context.isRawInput ? context.inputType : context.inputType->childAt(0); + context.isRawInput ? inputType : inputType->childAt(0); // Unified decimal precision/scale convention across decimal aggregates: // precision/scale -> intermediate (partial) decimal type // auxPrecision/auxScale -> final result decimal type @@ -91,13 +92,12 @@ class DecimalSumAggregate : public exec::Aggregate { const auto [sumPrecision, sumScale] = getDecimalPrecisionScale(*sumType_.get()); return jit::HashAggrJitDescriptor{ - .kind = jit::HashAggrJitKind::Sum, + .kind = jit::HashAggrJitKind::DecimalSum, .inputKind = valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 : jit::HashAggrJitValueKind::Int128, .accumulatorKind = jit::HashAggrJitValueKind::Int128, - .countStar = false, - .mergeInput = !context.isRawInput, + .context = context, .decimal = true, .inputShape = context.isRawInput ? 
jit::HashAggrJitRuntimeShape::Scalar : jit::HashAggrJitRuntimeShape::Row, diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index 35b3a3950..b147efcf8 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -46,14 +46,15 @@ class SumAggregate : public SumAggregateBase { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputCount != 1 || !context.inputType) { + if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { return false; } - if (context.inputType->isRow() || context.inputType->isDecimal()) { + const auto& inputType = context.inputTypes[0]; + if (inputType->isRow() || inputType->isDecimal()) { return false; } - return jit::isHashAggrJitSupportedType(context.inputType->kind()) || - context.inputType->kind() == TypeKind::HUGEINT; + return jit::isHashAggrJitSupportedType(inputType->kind()) || + inputType->kind() == TypeKind::HUGEINT; } std::optional createHashAggrJitDescriptor( @@ -62,7 +63,7 @@ class SumAggregate : public SumAggregateBase { return std::nullopt; } - auto inputKind = jit::hashAggrJitValueKind(context.inputType->kind()); + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); if (!inputKind.has_value()) { return std::nullopt; } @@ -77,8 +78,7 @@ class SumAggregate : public SumAggregateBase { .kind = jit::HashAggrJitKind::Sum, .inputKind = *inputKind, .accumulatorKind = accumulatorKind, - .countStar = false, - .mergeInput = !context.isRawInput, + .context = context, .decimal = false, .inputShape = jit::HashAggrJitRuntimeShape::Scalar, .outputShape = jit::HashAggrJitRuntimeShape::Scalar, diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 933af9cef..4475144e1 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -961,7 +961,7 
@@ bool genAddDenseIR( input = std::make_unique(codegen, inputRuntime); } - if (checkInputNulls && !slot.desc.countStar) { + if (checkInputNulls && !slot.desc.isCountStar()) { auto* nulls = input->loadNulls(); auto* nullCheckBlock = llvm::BasicBlock::Create(context, "slot_null_check", func, end); @@ -981,8 +981,9 @@ bool genAddDenseIR( if (slot.desc.ops == nullptr) { return false; } - auto* addFn = - slot.desc.mergeInput ? slot.desc.ops->addIntermediateResults : slot.desc.ops->addRawInput; + auto* addFn = !slot.desc.isRawInput() + ? slot.desc.ops->addIntermediateResults + : slot.desc.ops->addRawInput; if (addFn == nullptr) { return false; } @@ -1079,13 +1080,14 @@ HashAggrJitChunk::HashAggrJitChunk( out << "jit_hashaggr_v2_" << (partialOutput_ ? "partial" : "final") << "_n" << slots_.size(); for (const auto& slot : slots_) { - out << "_" << (slot.desc.ops != nullptr ? slot.desc.ops->id : "unknown") - << "_" << static_cast(slot.desc.kind) + out << "_" << hashAggrJitKindName(slot.desc.kind) << "_" + << static_cast(slot.desc.kind) << hashAggrJitValueKindName(slot.desc.inputKind) << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" << slot.offset << "n" << slot.nullByte << "m" - << static_cast(slot.nullMask) << (slot.desc.countStar ? "s" : "x") - << (slot.desc.mergeInput ? "g" : "r") + << static_cast(slot.nullMask) + << (slot.desc.isCountStar() ? "s" : "x") + << (!slot.desc.isRawInput() ? "g" : "r") << (slot.desc.decimal ? 
"d" : "n") << "i" << hashAggrJitRuntimeShapeName(slot.desc.inputShape) << "o" << hashAggrJitRuntimeShapeName(slot.desc.outputShape); @@ -1097,6 +1099,26 @@ HashAggrJitChunk::HashAggrJitChunk( extractFunctionName_ = functionName_ + "_extract"; } +std::string hashAggrJitKindName(HashAggrJitKind kind) { + switch (kind) { + case HashAggrJitKind::Count: + return "count"; + case HashAggrJitKind::Sum: + return "sum"; + case HashAggrJitKind::DecimalSum: + return "decimal_sum"; + case HashAggrJitKind::Min: + return "min"; + case HashAggrJitKind::Max: + return "max"; + case HashAggrJitKind::Avg: + return "avg"; + case HashAggrJitKind::DecimalAvg: + return "decimal_avg"; + } + return "unknown"; +} + std::string hashAggrJitValueKindName(HashAggrJitValueKind kind) { switch (kind) { case HashAggrJitValueKind::Bool: @@ -1161,11 +1183,11 @@ bool isHashAggrJitSupportedType(TypeKind kind) { std::string HashAggrJitDescriptor::signature() const { return fmt::format( "{}_{}_{}_{}_{}_{}_{}_{}", - ops != nullptr ? 
ops->id : "unknown", + hashAggrJitKindName(kind), static_cast(kind), hashAggrJitValueKindName(inputKind), hashAggrJitValueKindName(accumulatorKind), - mergeInput, + !isRawInput(), decimal, hashAggrJitRuntimeShapeName(inputShape), hashAggrJitRuntimeShapeName(outputShape)); diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 1361e0c09..a4aa114c8 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -39,7 +39,6 @@ struct HashAggrJitOps { const HashAggrJitSlot&, const HashAggrJitExtractTarget&); - const char* id; CreateFn initGroup; AddFn addRawInput; AddFn addIntermediateResults; diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index 201c02a4c..0194c1d21 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -5,6 +5,7 @@ #include #include #include +#include #include "bolt/type/Type.h" @@ -81,20 +82,21 @@ union HashAggrJitOutputRuntime { struct HashAggrJitPlanContext { bool isRawInput{false}; bool isPartialOutput{false}; - int32_t inputCount{0}; - TypePtr inputType; + std::vector inputTypes; bool isCountStar() const { - return isRawInput && inputCount == 0; + return isRawInput && inputTypes.empty(); } }; enum class HashAggrJitKind : uint8_t { Count, Sum, + DecimalSum, Min, Max, Avg, + DecimalAvg, }; enum class HashAggrJitValueKind : uint8_t { @@ -123,8 +125,7 @@ struct HashAggrJitDescriptor { HashAggrJitKind kind; HashAggrJitValueKind inputKind; HashAggrJitValueKind accumulatorKind; - bool countStar{false}; - bool mergeInput{false}; + HashAggrJitPlanContext context; bool decimal{false}; HashAggrJitRuntimeShape inputShape{HashAggrJitRuntimeShape::Scalar}; HashAggrJitRuntimeShape outputShape{HashAggrJitRuntimeShape::Scalar}; @@ -138,6 +139,14 @@ struct HashAggrJitDescriptor { int32_t auxScale{0}; const HashAggrJitOps* ops{nullptr}; + bool isCountStar() const { + return context.isCountStar(); + } + + bool 
isRawInput() const { + return context.isRawInput; + } + // std::string signature() const; }; @@ -153,6 +162,7 @@ struct HashAggrJitSlot { bool isHashAggrJitSupportedType(TypeKind kind); std::optional hashAggrJitValueKind(TypeKind kind); +std::string hashAggrJitKindName(HashAggrJitKind kind); std::string hashAggrJitValueKindName(HashAggrJitValueKind kind); // Per-aggregate codegen function tables. The definitions (which reference diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index e8d061d1a..aa065a68c 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -147,7 +147,6 @@ void compileAvgExtractValues( const HashAggrJitOps* getAvgOps() { static const HashAggrJitOps kOps{ - "avg", &compileAvgInitGroup, &compileAvgAddRawInput, &compileAvgAddIntermediateResults, diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp index fe6b05679..06f2a3d45 100644 --- a/bolt/jit/aggregation/ops/CountOps.cpp +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -54,7 +54,7 @@ void compileCountAddIntermediateResults( const HashAggrJitSlot& slot, llvm::BasicBlock*) { llvm::Value* inc = nullptr; - if (slot.desc.countStar) { + if (slot.desc.isCountStar()) { inc = codegen.builder().getInt64(1); } else { auto* inputRow = input.read(row, slot.desc.inputKind); @@ -102,7 +102,6 @@ void compileCountExtractValues( const HashAggrJitOps* getCountOps() { static const HashAggrJitOps kOps{ - "count", &compileCountInitGroup, &compileCountAddRawInput, &compileCountAddIntermediateResults, diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index 754423701..c1fc62724 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -84,16 +84,16 @@ void compileDecimalAvgAddIntermediateResults( auto* function = b.GetInsertBlock()->getParent(); auto* continueBlock = llvm::BasicBlock::Create( 
codegen.module().getContext(), - "avg_decimal_merge_cont", + "decimal_avg_merge_cont", function, nextBlock); auto* overflowBlock = llvm::BasicBlock::Create( codegen.module().getContext(), - "avg_decimal_merge_overflow", + "decimal_avg_merge_overflow", function, continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), "avg_decimal_merge", function, continueBlock); + codegen.module().getContext(), "decimal_avg_merge", function, continueBlock); const auto sumKind = decimalKindForPrecision(slot.desc.precision); auto* sumRow = input.readRowField(row, 0, sumKind); auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); @@ -192,7 +192,6 @@ void compileDecimalAvgExtractValues( const HashAggrJitOps* getDecimalAvgOps() { static const HashAggrJitOps kOps{ - "avg_decimal", &compileDecimalAvgInitGroup, &compileDecimalAvgAddRawInput, &compileDecimalAvgAddIntermediateResults, diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index f9a772a5c..e2bba3281 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -82,16 +82,16 @@ void compileDecimalSumAddIntermediateResults( auto* function = b.GetInsertBlock()->getParent(); auto* continueBlock = llvm::BasicBlock::Create( codegen.module().getContext(), - "sum_decimal_merge_cont", + "decimal_sum_merge_cont", function, nextBlock); auto* overflowBlock = llvm::BasicBlock::Create( codegen.module().getContext(), - "sum_decimal_merge_overflow", + "decimal_sum_merge_overflow", function, continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( - codegen.module().getContext(), "sum_decimal_merge", function, continueBlock); + codegen.module().getContext(), "decimal_sum_merge", function, continueBlock); const auto sumKind = decimalKindForPrecision(slot.desc.precision); auto* sumRow = input.readRowField(row, 0, sumKind); auto* incomingIsEmpty = @@ -193,7 +193,6 @@ void 
compileDecimalSumExtractValues( const HashAggrJitOps* getDecimalSumOps() { static const HashAggrJitOps kOps{ - "sum_decimal", &compileDecimalSumInitGroup, &compileDecimalSumAddRawInput, &compileDecimalSumAddIntermediateResults, diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp index 9dfa04ed7..058e53e70 100644 --- a/bolt/jit/aggregation/ops/MinMaxOps.cpp +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -108,7 +108,6 @@ void compileMinMaxExtractValues( const HashAggrJitOps* getMinMaxOps() { static const HashAggrJitOps kOps{ - "minmax", &compileMinMaxInitGroup, &compileMinMaxUpdate, &compileMinMaxUpdate, diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp index 95cc236fa..92c6621a3 100644 --- a/bolt/jit/aggregation/ops/SumOps.cpp +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -85,7 +85,6 @@ void compileSumExtractValues( const HashAggrJitOps* getSumOps() { static const HashAggrJitOps kOps{ - "sum", &compileSumInitGroup, &compileSumAccumulate, &compileSumAccumulate, From 35f2b4c7846ceb4fa91b8223b0a6c6f9d7489d86 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 14 Jun 2026 20:14:40 +0800 Subject: [PATCH 75/98] part2: refactor HashAggrJitDescriptor --- bolt/exec/GroupingSet.cpp | 20 +++++----- .../prestosql/aggregates/CountAggregate.cpp | 9 +---- .../prestosql/aggregates/MinMaxAggregates.cpp | 9 +---- .../sparksql/aggregates/AverageAggregate.cpp | 39 ++----------------- .../sparksql/aggregates/DecimalSumAggregate.h | 17 +------- .../sparksql/aggregates/SumAggregate.cpp | 9 +---- bolt/jit/aggregation/HashAggrJit.cpp | 22 +++++------ bolt/jit/aggregation/HashAggrJitTypes.h | 32 +++++++++------ bolt/jit/aggregation/ops/AvgOps.cpp | 6 +-- bolt/jit/aggregation/ops/CountOps.cpp | 4 +- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 37 ++++++++---------- bolt/jit/aggregation/ops/DecimalOps.h | 18 +++++++++ bolt/jit/aggregation/ops/DecimalSumOps.cpp | 33 +++++++--------- 
bolt/jit/aggregation/ops/MinMaxOps.cpp | 4 +- bolt/jit/aggregation/ops/SumOps.cpp | 4 +- 15 files changed, 104 insertions(+), 159 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index d822751a4..6798f4808 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -154,7 +154,7 @@ bool fillHashAggrJitRowInputRuntime( // hot merge path; populating its raw pointer lets generated IR load it // directly instead of calling the per-row jit_GetDecodedRowField* helper // (which rebuilds a field DecodedVector on every call). - if (slot.desc.inputShape != jit::HashAggrJitRuntimeShape::Row) { + if (slot.desc.inputShape() != jit::HashAggrJitRuntimeShape::Row) { return false; } const auto* base = decoded.base(); @@ -246,13 +246,13 @@ std::string hashAggrJitSlotDebugString( out << "]"; } out << " kind=" << static_cast(slot.desc.kind) - << " inputKind=" << jit::hashAggrJitValueKindName(slot.desc.inputKind) + << " inputKind=" << jit::hashAggrJitValueKindName(slot.desc.rawInputKind) << " accKind=" << jit::hashAggrJitValueKindName(slot.desc.accumulatorKind) << " offset=" << slot.offset << " nullByte=" << slot.nullByte << " nullMask=" << static_cast(slot.nullMask) << " countStar=" << slot.desc.isCountStar() << " mergeInput=" << !slot.desc.isRawInput() - << " decimal=" << slot.desc.decimal + << " decimal=" << slot.desc.isDecimal() << " kindName=" << jit::hashAggrJitKindName(slot.desc.kind); return out.str(); } @@ -292,14 +292,12 @@ std::optional makeHashAggrJitSlot( inputTypes = {aggregate.intermediateType}; } - if (!(isRawInput && inputTypes.empty()) && inputTypes.size() != 1) { - return std::nullopt; - } - const jit::HashAggrJitPlanContext context{ .isRawInput = isRawInput, .isPartialOutput = isPartialOutput, - .inputTypes = std::move(inputTypes)}; + .inputTypes = std::move(inputTypes), + .outputType = isPartialOutput ? 
aggregate.intermediateType + : aggregate.function->resultType()}; if (!aggregate.function->supportsHashAggrJit(context)) { return std::nullopt; } @@ -1199,7 +1197,7 @@ void GroupingSet::runHashAggrJitAddChunks( hashAggrJitInputVectors_[slotIndex] = arg; hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); const bool usesRowInputRuntime = - slot.desc.inputShape == jit::HashAggrJitRuntimeShape::Row; + slot.desc.inputShape() == jit::HashAggrJitRuntimeShape::Row; if (usesRowInputRuntime) { if (!fillHashAggrJitRowInputRuntime( hashAggrJitInputRuntimes_[slotIndex], @@ -1290,7 +1288,7 @@ void GroupingSet::runHashAggrJitExtractChunks( auto& aggregateVector = result->childAt(slot.aggregateIndex + aggregateOutputOffset); const auto expectedEncoding = - slot.desc.outputShape == jit::HashAggrJitRuntimeShape::Row + slot.desc.outputShape() == jit::HashAggrJitRuntimeShape::Row ? VectorEncoding::Simple::ROW : VectorEncoding::Simple::FLAT; if (aggregateVector->encoding() != expectedEncoding) { @@ -1323,7 +1321,7 @@ void GroupingSet::runHashAggrJitExtractChunks( .vector = aggregateVector.get()}; } else if ( aggregateVector->encoding() == VectorEncoding::Simple::ROW && - slot.desc.outputShape == jit::HashAggrJitRuntimeShape::Row) { + slot.desc.outputShape() == jit::HashAggrJitRuntimeShape::Row) { if (!fillHashAggrJitRowOutputRuntime( hashAggrJitOutputRuntimes_[slotIndex], hashAggrJitRowOutputChildren_[slotIndex], diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index d859273cf..fa2e8839c 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -88,16 +88,9 @@ class CountAggregate : public SimpleNumericAggregate { } return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::Count, - .inputKind = inputKind, + .rawInputKind = inputKind, .accumulatorKind = jit::HashAggrJitValueKind::Int64, .context = context, - .decimal = false, 
- .inputShape = jit::HashAggrJitRuntimeShape::Scalar, - .outputShape = jit::HashAggrJitRuntimeShape::Scalar, - .precision = 0, - .scale = 0, - .auxPrecision = 0, - .auxScale = 0, .ops = jit::getCountOps()}; } #endif diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index 1416bd573..f3b1ca220 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -76,16 +76,9 @@ class MinMaxAggregate : public SimpleNumericAggregate { } return jit::HashAggrJitDescriptor{ .kind = jitKind(), - .inputKind = *inputKind, + .rawInputKind = *inputKind, .accumulatorKind = *inputKind, .context = context, - .decimal = false, - .inputShape = jit::HashAggrJitRuntimeShape::Scalar, - .outputShape = jit::HashAggrJitRuntimeShape::Scalar, - .precision = 0, - .scale = 0, - .auxPrecision = 0, - .auxScale = 0, .ops = jit::getMinMaxOps()}; } diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 7d2de2b92..12c06591e 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -71,18 +71,9 @@ class AverageAggregate if (!context.isRawInput) { return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::Avg, - .inputKind = jit::HashAggrJitValueKind::Double, + .rawInputKind = jit::HashAggrJitValueKind::Double, .accumulatorKind = jit::HashAggrJitValueKind::Double, .context = context, - .decimal = false, - .inputShape = jit::HashAggrJitRuntimeShape::Row, - .outputShape = context.isPartialOutput - ? 
jit::HashAggrJitRuntimeShape::Row - : jit::HashAggrJitRuntimeShape::Scalar, - .precision = 0, - .scale = 0, - .auxPrecision = 0, - .auxScale = 0, .ops = jit::getAvgOps()}; } @@ -92,18 +83,9 @@ class AverageAggregate } return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::Avg, - .inputKind = *inputKind, + .rawInputKind = *inputKind, .accumulatorKind = jit::HashAggrJitValueKind::Double, .context = context, - .decimal = false, - .inputShape = jit::HashAggrJitRuntimeShape::Scalar, - .outputShape = context.isPartialOutput - ? jit::HashAggrJitRuntimeShape::Row - : jit::HashAggrJitRuntimeShape::Scalar, - .precision = 0, - .scale = 0, - .auxPrecision = 0, - .auxScale = 0, .ops = jit::getAvgOps()}; } #endif @@ -182,28 +164,13 @@ class DecimalAverageAggregate : public DecimalAggregate { const auto& inputType = context.inputTypes[0]; const auto& valueType = context.isRawInput ? inputType : inputType->childAt(0); - const auto [sumPrecision, sumScale] = - getDecimalPrecisionScale(*sumType_.get()); - const auto [resultPrecision, resultScale] = context.isPartialOutput - ? std::pair{0, 0} - : getDecimalPrecisionScale(*this->resultType().get()); return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::DecimalAvg, - .inputKind = valueType->isShortDecimal() + .rawInputKind = valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 : jit::HashAggrJitValueKind::Int128, .accumulatorKind = jit::HashAggrJitValueKind::Int128, .context = context, - .decimal = true, - .inputShape = context.isRawInput ? jit::HashAggrJitRuntimeShape::Scalar - : jit::HashAggrJitRuntimeShape::Row, - .outputShape = context.isPartialOutput - ? 
jit::HashAggrJitRuntimeShape::Row - : jit::HashAggrJitRuntimeShape::Scalar, - .precision = sumPrecision, - .scale = sumScale, - .auxPrecision = resultPrecision, - .auxScale = resultScale, .ops = jit::getDecimalAvgOps()}; } #endif diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index d852acb72..730d738e6 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -84,28 +84,13 @@ class DecimalSumAggregate : public exec::Aggregate { const auto& inputType = context.inputTypes[0]; const auto& valueType = context.isRawInput ? inputType : inputType->childAt(0); - // Unified decimal precision/scale convention across decimal aggregates: - // precision/scale -> intermediate (partial) decimal type - // auxPrecision/auxScale -> final result decimal type - // For decimal sum the intermediate and final decimal type are the same - // (both sumType_), so aux* mirror precision/scale. - const auto [sumPrecision, sumScale] = - getDecimalPrecisionScale(*sumType_.get()); return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::DecimalSum, - .inputKind = valueType->isShortDecimal() + .rawInputKind = valueType->isShortDecimal() ? jit::HashAggrJitValueKind::Int64 : jit::HashAggrJitValueKind::Int128, .accumulatorKind = jit::HashAggrJitValueKind::Int128, .context = context, - .decimal = true, - .inputShape = context.isRawInput ? 
jit::HashAggrJitRuntimeShape::Scalar - : jit::HashAggrJitRuntimeShape::Row, - .outputShape = jit::HashAggrJitRuntimeShape::Scalar, - .precision = sumPrecision, - .scale = sumScale, - .auxPrecision = sumPrecision, - .auxScale = sumScale, .ops = jit::getDecimalSumOps()}; } #endif diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index b147efcf8..a92deebc5 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -76,16 +76,9 @@ class SumAggregate : public SumAggregateBase { return jit::HashAggrJitDescriptor{ .kind = jit::HashAggrJitKind::Sum, - .inputKind = *inputKind, + .rawInputKind = *inputKind, .accumulatorKind = accumulatorKind, .context = context, - .decimal = false, - .inputShape = jit::HashAggrJitRuntimeShape::Scalar, - .outputShape = jit::HashAggrJitRuntimeShape::Scalar, - .precision = 0, - .scale = 0, - .auxPrecision = 0, - .auxScale = 0, .ops = jit::getSumOps()}; } #endif diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 4475144e1..8fcd7330b 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -848,11 +848,11 @@ char hashAggrJitRuntimeShapeName(HashAggrJitRuntimeShape shape) { } bool usesRowInputRuntime(const HashAggrJitSlot& slot) { - return slot.desc.inputShape == HashAggrJitRuntimeShape::Row; + return slot.desc.inputShape() == HashAggrJitRuntimeShape::Row; } -bool usesRowOutputRuntime(const HashAggrJitSlot& slot, bool partialOutput) { - return partialOutput && slot.desc.outputShape == HashAggrJitRuntimeShape::Row; +bool usesRowOutputRuntime(const HashAggrJitSlot& slot) { + return slot.desc.outputShape() == HashAggrJitRuntimeShape::Row; } bool genAddDenseIR( @@ -1045,7 +1045,7 @@ bool genExtractIR( auto* outputAddr = builder.CreateConstInBoundsGEP1_64(i8PtrTy, resultVectors, i); auto* outputRuntime = builder.CreateLoad(i8PtrTy, 
outputAddr); std::unique_ptr output; - if (usesRowOutputRuntime(slot, partialOutput)) { + if (usesRowOutputRuntime(slot)) { output = std::make_unique(codegen, outputRuntime); } else { output = @@ -1082,15 +1082,15 @@ HashAggrJitChunk::HashAggrJitChunk( for (const auto& slot : slots_) { out << "_" << hashAggrJitKindName(slot.desc.kind) << "_" << static_cast(slot.desc.kind) - << hashAggrJitValueKindName(slot.desc.inputKind) + << hashAggrJitValueKindName(slot.desc.rawInputKind) << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" << slot.offset << "n" << slot.nullByte << "m" << static_cast(slot.nullMask) << (slot.desc.isCountStar() ? "s" : "x") << (!slot.desc.isRawInput() ? "g" : "r") - << (slot.desc.decimal ? "d" : "n") << "i" - << hashAggrJitRuntimeShapeName(slot.desc.inputShape) << "o" - << hashAggrJitRuntimeShapeName(slot.desc.outputShape); + << (slot.desc.isDecimal() ? "d" : "n") << "i" + << hashAggrJitRuntimeShapeName(slot.desc.inputShape()) << "o" + << hashAggrJitRuntimeShapeName(slot.desc.outputShape()); } functionName_ = out.str(); initFunctionName_ = functionName_ + "_init"; @@ -1188,9 +1188,9 @@ std::string HashAggrJitDescriptor::signature() const { hashAggrJitValueKindName(inputKind), hashAggrJitValueKindName(accumulatorKind), !isRawInput(), - decimal, - hashAggrJitRuntimeShapeName(inputShape), - hashAggrJitRuntimeShapeName(outputShape)); + isDecimal(), + hashAggrJitRuntimeShapeName(inputShape()), + hashAggrJitRuntimeShapeName(outputShape())); } */ diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index 0194c1d21..21b36e5cb 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -83,6 +83,7 @@ struct HashAggrJitPlanContext { bool isRawInput{false}; bool isPartialOutput{false}; std::vector inputTypes; + TypePtr outputType; bool isCountStar() const { return isRawInput && inputTypes.empty(); @@ -123,20 +124,9 @@ struct HashAggrJitOps; struct 
HashAggrJitDescriptor { HashAggrJitKind kind; - HashAggrJitValueKind inputKind; + HashAggrJitValueKind rawInputKind; HashAggrJitValueKind accumulatorKind; HashAggrJitPlanContext context; - bool decimal{false}; - HashAggrJitRuntimeShape inputShape{HashAggrJitRuntimeShape::Scalar}; - HashAggrJitRuntimeShape outputShape{HashAggrJitRuntimeShape::Scalar}; - // Result decimal precision/scale, used by decimal extract overflow checks. - // Only meaningful when decimal == true. - int32_t precision{0}; - int32_t scale{0}; - // Secondary decimal precision/scale. For decimal avg extract, precision/scale - // carry the intermediate sum type and aux* carry the result type. - int32_t auxPrecision{0}; - int32_t auxScale{0}; const HashAggrJitOps* ops{nullptr}; bool isCountStar() const { @@ -147,6 +137,24 @@ struct HashAggrJitDescriptor { return context.isRawInput; } + bool isDecimal() const { + return kind == HashAggrJitKind::DecimalSum || + kind == HashAggrJitKind::DecimalAvg; + } + + HashAggrJitRuntimeShape inputShape() const { + return context.inputTypes.size() == 1 && context.inputTypes[0] && + context.inputTypes[0]->isRow() + ? HashAggrJitRuntimeShape::Row + : HashAggrJitRuntimeShape::Scalar; + } + + HashAggrJitRuntimeShape outputShape() const { + return context.outputType && context.outputType->isRow() + ? 
HashAggrJitRuntimeShape::Row + : HashAggrJitRuntimeShape::Scalar; + } + // std::string signature() const; }; diff --git a/bolt/jit/aggregation/ops/AvgOps.cpp b/bolt/jit/aggregation/ops/AvgOps.cpp index aa065a68c..98463f188 100644 --- a/bolt/jit/aggregation/ops/AvgOps.cpp +++ b/bolt/jit/aggregation/ops/AvgOps.cpp @@ -48,10 +48,10 @@ void compileAvgAddRawInput( llvm::Value* row, const HashAggrJitSlot& slot, llvm::BasicBlock*) { - auto* inputRow = input.read(row, slot.desc.inputKind); + auto* inputRow = input.read(row, slot.desc.rawInputKind); auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); - auto* value = - codegen.castValue(rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, slot.desc.accumulatorKind); codegen.clearAccumulatorNull(group, slot); auto* oldSum = codegen.loadValue( group, codegen.llvmType(slot.desc.accumulatorKind), slot.offset); diff --git a/bolt/jit/aggregation/ops/CountOps.cpp b/bolt/jit/aggregation/ops/CountOps.cpp index 06f2a3d45..cc43c20ea 100644 --- a/bolt/jit/aggregation/ops/CountOps.cpp +++ b/bolt/jit/aggregation/ops/CountOps.cpp @@ -57,10 +57,10 @@ void compileCountAddIntermediateResults( if (slot.desc.isCountStar()) { inc = codegen.builder().getInt64(1); } else { - auto* inputRow = input.read(row, slot.desc.inputKind); + auto* inputRow = input.read(row, slot.desc.rawInputKind); inc = codegen.castValue( IRRow::getValue(codegen.builder(), inputRow), - slot.desc.inputKind, + slot.desc.rawInputKind, HashAggrJitValueKind::Int64); } addInc(codegen, group, slot, inc); diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index c1fc62724..b7dfad7f4 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -24,12 +24,6 @@ constexpr int32_t kCountOffset = constexpr int32_t kOverflowOffset = static_cast(offsetof(JitDecimalAvgState, overflow)); -HashAggrJitValueKind 
decimalKindForPrecision(int32_t precision) { - return precision > bytedance::bolt::ShortDecimalType::kMaxPrecision - ? HashAggrJitValueKind::Int128 - : HashAggrJitValueKind::Int64; -} - void compileDecimalAvgInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -56,10 +50,10 @@ void compileDecimalAvgAddRawInput( const HashAggrJitSlot& slot, llvm::BasicBlock*) { auto& b = codegen.builder(); - auto* inputRow = input.read(row, slot.desc.inputKind); + auto* inputRow = input.read(row, slot.desc.rawInputKind); auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); - auto* value = - codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); emitDecimalAddWithOverflow( codegen, group, slot.offset + kSumOffset, slot.offset + kOverflowOffset, value); @@ -94,7 +88,9 @@ void compileDecimalAvgAddIntermediateResults( continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "decimal_avg_merge", function, continueBlock); - const auto sumKind = decimalKindForPrecision(slot.desc.precision); + const auto [sumPrecision, _] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes[0]); + const auto sumKind = hashAggrJitDecimalKindForPrecision(sumPrecision); auto* sumRow = input.readRowField(row, 0, sumKind); auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); auto* sumIsNull = IRRow::getIsNull(b, sumRow); @@ -136,13 +132,12 @@ void emitDecimalAvgExtract( const HashAggrJitSlot& slot, bool partialOutput) { auto& b = codegen.builder(); - // long/short decimal of the written sum column: partial output writes the - // intermediate sum decimal (precision/scale); final output writes the result - // decimal (auxPrecision/auxScale). - const int32_t outPrecision = - partialOutput ? 
slot.desc.precision : slot.desc.auxPrecision; - const bool longDecimal = - decimalKindForPrecision(outPrecision) == HashAggrJitValueKind::Int128; + const auto [inputPrecision, inputScale] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes[0]); + const auto [outputPrecision, outputScale] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType); + const bool longDecimal = hashAggrJitDecimalKindForPrecision( + outputPrecision) == HashAggrJitValueKind::Int128; const char* fn = partialOutput ? (longDecimal ? "jit_HashAggrExtractPartialLongDecimalAvg" : "jit_HashAggrExtractPartialShortDecimalAvg") @@ -154,10 +149,10 @@ void emitDecimalAvgExtract( row, group, b.getInt32(slot.offset), - b.getInt32(slot.desc.precision), - b.getInt32(slot.desc.scale), - b.getInt32(slot.desc.auxPrecision), - b.getInt32(slot.desc.auxScale)}); + b.getInt32(inputPrecision), + b.getInt32(inputScale), + b.getInt32(outputPrecision), + b.getInt32(outputScale)}); } void compileDecimalAvgExtractAccumulators( diff --git a/bolt/jit/aggregation/ops/DecimalOps.h b/bolt/jit/aggregation/ops/DecimalOps.h index 423a5e289..ba5a3a5e7 100644 --- a/bolt/jit/aggregation/ops/DecimalOps.h +++ b/bolt/jit/aggregation/ops/DecimalOps.h @@ -17,6 +17,24 @@ // file-local in their respective ops files. namespace bytedance::bolt::jit { +inline HashAggrJitValueKind hashAggrJitDecimalKindForPrecision( + int32_t precision) { + return precision > bytedance::bolt::ShortDecimalType::kMaxPrecision + ? HashAggrJitValueKind::Int128 + : HashAggrJitValueKind::Int64; +} + +inline const TypePtr& hashAggrJitDecimalValueType(const TypePtr& type) { + return type->isRow() ? type->childAt(0) : type; +} + +inline std::pair hashAggrJitDecimalPrecisionScale( + const TypePtr& type) { + const auto [precision, scale] = + getDecimalPrecisionScale(*hashAggrJitDecimalValueType(type)); + return {precision, scale}; +} + // Inline i128 accumulate-with-overflow used by decimal sum/avg add+merge. 
// Loads the i128 sum at 'group + sumOffset' and the i64 overflow counter at // 'group + overflowOffset', computes sum += addend, updates the overflow diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index e2bba3281..2c7da8399 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -24,12 +24,6 @@ constexpr int32_t kOverflowOffset = constexpr int32_t kIsEmptyOffset = static_cast(offsetof(JitDecimalSumState, isEmpty)); -HashAggrJitValueKind decimalKindForPrecision(int32_t precision) { - return precision > bytedance::bolt::ShortDecimalType::kMaxPrecision - ? HashAggrJitValueKind::Int128 - : HashAggrJitValueKind::Int64; -} - void compileDecimalSumInitGroup( HashAggrJitCodegen& codegen, llvm::Value* group, @@ -56,10 +50,10 @@ void compileDecimalSumAddRawInput( const HashAggrJitSlot& slot, llvm::BasicBlock*) { auto& b = codegen.builder(); - auto* inputRow = input.read(row, slot.desc.inputKind); + auto* inputRow = input.read(row, slot.desc.rawInputKind); auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); - auto* value = - codegen.castValue(rawValue, slot.desc.inputKind, HashAggrJitValueKind::Int128); + auto* value = codegen.castValue( + rawValue, slot.desc.rawInputKind, HashAggrJitValueKind::Int128); codegen.clearAccumulatorNull(group, slot); emitDecimalAddWithOverflow( codegen, @@ -92,7 +86,9 @@ void compileDecimalSumAddIntermediateResults( continueBlock); auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "decimal_sum_merge", function, continueBlock); - const auto sumKind = decimalKindForPrecision(slot.desc.precision); + const auto [sumPrecision, _] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes[0]); + const auto sumKind = hashAggrJitDecimalKindForPrecision(sumPrecision); auto* sumRow = input.readRowField(row, 0, sumKind); auto* incomingIsEmpty = input.readRowFieldValue(row, 1, HashAggrJitValueKind::Bool); @@ 
-139,13 +135,12 @@ void emitDecimalSumExtract( const HashAggrJitSlot& slot, bool partialOutput) { auto& b = codegen.builder(); - // long/short decimal is decided by the actual output decimal type, not the - // input kind: precision/scale carry the intermediate (partial) decimal type, - // auxPrecision/auxScale carry the final result decimal type. - const int32_t outPrecision = - partialOutput ? slot.desc.precision : slot.desc.auxPrecision; - const bool longDecimal = - decimalKindForPrecision(outPrecision) == HashAggrJitValueKind::Int128; + // long/short decimal and overflow precision are decided by the actual + // output decimal type of this aggregation stage. + const auto [outPrecision, outScale] = + hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType); + const bool longDecimal = hashAggrJitDecimalKindForPrecision(outPrecision) == + HashAggrJitValueKind::Int128; const char* fn = partialOutput ? (longDecimal ? "jit_HashAggrExtractPartialLongDecimalSum" : "jit_HashAggrExtractPartialShortDecimalSum") @@ -157,8 +152,8 @@ void emitDecimalSumExtract( row, group, b.getInt32(slot.offset), - b.getInt32(slot.desc.precision), - b.getInt32(slot.desc.scale)}); + b.getInt32(outPrecision), + b.getInt32(outScale)}); } void compileDecimalSumExtractAccumulators( diff --git a/bolt/jit/aggregation/ops/MinMaxOps.cpp b/bolt/jit/aggregation/ops/MinMaxOps.cpp index 058e53e70..fd2440fae 100644 --- a/bolt/jit/aggregation/ops/MinMaxOps.cpp +++ b/bolt/jit/aggregation/ops/MinMaxOps.cpp @@ -33,10 +33,10 @@ void compileMinMaxUpdate( llvm::Value* row, const HashAggrJitSlot& slot, llvm::BasicBlock*) { - auto* inputRow = input.read(row, slot.desc.inputKind); + auto* inputRow = input.read(row, slot.desc.rawInputKind); auto* value = codegen.castValue( IRRow::getValue(codegen.builder(), inputRow), - slot.desc.inputKind, + slot.desc.rawInputKind, slot.desc.accumulatorKind); auto* type = codegen.llvmType(slot.desc.accumulatorKind); auto* oldValue = codegen.loadValue(group, type, slot.offset); 
diff --git a/bolt/jit/aggregation/ops/SumOps.cpp b/bolt/jit/aggregation/ops/SumOps.cpp index 92c6621a3..1215e7ba6 100644 --- a/bolt/jit/aggregation/ops/SumOps.cpp +++ b/bolt/jit/aggregation/ops/SumOps.cpp @@ -35,10 +35,10 @@ void compileSumAccumulate( llvm::Value* row, const HashAggrJitSlot& slot, llvm::BasicBlock*) { - auto* inputRow = input.read(row, slot.desc.inputKind); + auto* inputRow = input.read(row, slot.desc.rawInputKind); auto* rawValue = IRRow::getValue(codegen.builder(), inputRow); auto* value = codegen.castValue( - rawValue, slot.desc.inputKind, slot.desc.accumulatorKind); + rawValue, slot.desc.rawInputKind, slot.desc.accumulatorKind); auto* accType = codegen.llvmType(slot.desc.accumulatorKind); codegen.clearAccumulatorNull(group, slot); auto* oldValue = codegen.loadValue(group, accType, slot.offset); From 026d6e36003aa1921fe83db9378ca8d5d4740cfa Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 14 Jun 2026 21:17:15 +0800 Subject: [PATCH 76/98] part3: refine HashAggrJitChunk function naming Derive chunk function names from compact hashed slot descriptions and avoid storing derived function-name members. Simplify JIT planning debug output to reuse slot and chunk descriptions. 
--- bolt/exec/GroupingSet.cpp | 96 ++++++------------------- bolt/jit/aggregation/HashAggrJit.cpp | 88 +++++++++++------------ bolt/jit/aggregation/HashAggrJit.h | 26 ++----- bolt/jit/aggregation/HashAggrJitTypes.h | 5 +- 4 files changed, 75 insertions(+), 140 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 6798f4808..28d5a9db7 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -229,50 +229,6 @@ bool fillHashAggrJitRowOutputRuntime( return true; } -std::string hashAggrJitSlotDebugString( - const jit::HashAggrJitSlot& slot, - const AggregateInfo* aggregate = nullptr) { - std::ostringstream out; - out << "agg#" << slot.aggregateIndex; - if (aggregate != nullptr) { - out << "(" << aggregate->name << ")"; - out << " inputs=["; - for (size_t i = 0; i < aggregate->rawInputTypes.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << hashAggrJitTypeName(aggregate->rawInputTypes[i]); - } - out << "]"; - } - out << " kind=" << static_cast(slot.desc.kind) - << " inputKind=" << jit::hashAggrJitValueKindName(slot.desc.rawInputKind) - << " accKind=" << jit::hashAggrJitValueKindName(slot.desc.accumulatorKind) - << " offset=" << slot.offset << " nullByte=" << slot.nullByte - << " nullMask=" << static_cast(slot.nullMask) - << " countStar=" << slot.desc.isCountStar() - << " mergeInput=" << !slot.desc.isRawInput() - << " decimal=" << slot.desc.isDecimal() - << " kindName=" << jit::hashAggrJitKindName(slot.desc.kind); - return out.str(); -} - -std::string hashAggrJitChunkDebugString( - const jit::HashAggrJitChunk& chunk, - const std::vector& aggregates) { - std::ostringstream out; - out << chunk.functionName() << " slots=["; - for (size_t i = 0; i < chunk.slots().size(); ++i) { - if (i > 0) { - out << "; "; - } - const auto& slot = chunk.slots()[i]; - out << hashAggrJitSlotDebugString(slot, &aggregates[slot.aggregateIndex]); - } - out << "] canExtract=" << chunk.canExtract() - << " codegenReady=" << 
chunk.isCodegenReady(); - return out.str(); -} #endif std::optional makeHashAggrJitSlot( @@ -1057,27 +1013,19 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { auto flushChunk = [&]() { if (currentChunkSlots.size() < minChunkWidth) { if (!currentChunkSlots.empty()) { - std::ostringstream out; - for (size_t i = 0; i < currentChunkSlots.size(); ++i) { - if (i > 0) { - out << "; "; - } - const auto& slot = currentChunkSlots[i]; - out << hashAggrJitSlotDebugString(slot, &aggregates_[slot.aggregateIndex]); - } VLOG(1) << "HashAggrJit discard chunk candidate due to width " << currentChunkSlots.size() << " < " << minChunkWidth - << ": [" << out.str() << "]"; + << "."; } currentChunkSlots.clear(); return; } - jit::HashAggrJitChunk chunk(std::move(currentChunkSlots), isPartial_); + jit::HashAggrJitChunk chunk( + std::move(currentChunkSlots), isRawInput_, isPartial_); if (chunk.codegen()) { hashAggrJitChunks_.push_back(std::move(chunk)); VLOG(1) << "HashAggrJit formed chunk: " - << hashAggrJitChunkDebugString( - hashAggrJitChunks_.back(), aggregates_); + << hashAggrJitChunks_.back().getDescription(); } else { VLOG(1) << "HashAggrJit chunk codegen failed for chunk " << chunk.functionName(); @@ -1090,14 +1038,19 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); if (!slot.has_value()) { VLOG(1) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" - << aggregates_[i].name << ") rawInputTypes=[" + << aggregates_[i].name << ") isRawInput=" << isRawInput_ + << " isPartialOutput=" << isPartial_ << " inputTypes=[" << [&]() { std::ostringstream out; - for (size_t j = 0; j < aggregates_[i].rawInputTypes.size(); ++j) { - if (j > 0) { - out << ", "; + if (isRawInput_) { + for (size_t j = 0; j < aggregates_[i].rawInputTypes.size(); ++j) { + if (j > 0) { + out << ", "; + } + out << hashAggrJitTypeName(aggregates_[i].rawInputTypes[j]); } - out << hashAggrJitTypeName(aggregates_[i].rawInputTypes[j]); + } 
else { + out << hashAggrJitTypeName(aggregates_[i].intermediateType); } return out.str(); }() @@ -1105,8 +1058,10 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { << " mask=" << aggregates_[i].mask.has_value() << " sortingKeys=" << aggregates_[i].sortingKeys.size() << " inputs=" << aggregates_[i].inputs.size() - << " intermediateType=" - << hashAggrJitTypeName(aggregates_[i].intermediateType); + << " outputType=" + << hashAggrJitTypeName( + isPartial_ ? aggregates_[i].intermediateType + : aggregates_[i].function->resultType()); flushChunk(); continue; } @@ -1115,7 +1070,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { flushChunk(); } VLOG(1) << "HashAggrJit aggregate is JIT-able: " - << hashAggrJitSlotDebugString(*slot, &aggregates_[i]); + << slot->getDescription(); currentChunkSlots.push_back(*slot); } @@ -1144,7 +1099,7 @@ void GroupingSet::runHashAggrJitAddChunks( for (auto& chunk : hashAggrJitChunks_) { if (!chunk.isCodegenReady()) { VLOG(1) << "HashAggrJit chunk is not codegen-ready, skip add: " - << hashAggrJitChunkDebugString(chunk, aggregates_); + << chunk.getDescription(); continue; } @@ -1225,8 +1180,7 @@ void GroupingSet::runHashAggrJitAddChunks( if (!canRunChunk) { VLOG(1) << "HashAggrJit chunk cannot run add path, fallback to non-JIT: " - << hashAggrJitChunkDebugString(chunk, aggregates_) - << " reason=" << skipReason; + << chunk.getDescription() << " reason=" << skipReason; continue; } @@ -1264,11 +1218,6 @@ void GroupingSet::runHashAggrJitExtractChunks( jitExtracted.assign(aggregates_.size(), 0); for (auto& chunk : hashAggrJitChunks_) { - if (!chunk.canExtract()) { - VLOG(1) << "HashAggrJit chunk cannot extract, fallback to non-JIT extract: " - << hashAggrJitChunkDebugString(chunk, aggregates_); - continue; - } const auto numSlots = chunk.slots().size(); hashAggrJitOutputRuntimes_.assign(numSlots, jit::HashAggrJitOutputRuntime{}); hashAggrJitRowOutputChildren_.resize(numSlots); @@ -1337,8 +1286,7 @@ void 
GroupingSet::runHashAggrJitExtractChunks( } if (!canRunChunk) { VLOG(1) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " - << hashAggrJitChunkDebugString(chunk, aggregates_) - << " reason=" << skipReason; + << chunk.getDescription() << " reason=" << skipReason; continue; } chunk.extract(groups.data(), groups.size(), hashAggrJitResultPtrs_.data()); diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 8fcd7330b..79f5993af 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -14,6 +14,7 @@ #include +#include "bolt/common/base/BitUtil.h" #include "bolt/common/base/Exceptions.h" #include "bolt/jit/ThrustJITv2.h" @@ -1072,31 +1073,49 @@ bool genExtractIR( } // namespace +std::string HashAggrJitSlot::getDescription() const { + std::ostringstream inputs; + inputs << "["; + for (size_t i = 0; i < desc.context.inputTypes.size(); ++i) { + if (i > 0) { + inputs << ","; + } + inputs << desc.context.inputTypes[i]->toString(); + } + inputs << "]"; + + return fmt::format( + "{}_raw{}_partial{}({})->{}@{}", + hashAggrJitKindName(desc.kind), + desc.context.isRawInput, + desc.context.isPartialOutput, + inputs.str(), + desc.context.outputType->toString(), + offset); +} + HashAggrJitChunk::HashAggrJitChunk( std::vector slots, - bool partialOutput) - : slots_(std::move(slots)), partialOutput_(partialOutput) { + bool isRawInput, + bool isPartialOutput) + : slots_(std::move(slots)), + isRawInput_(isRawInput), + isPartialOutput_(isPartialOutput) { + const auto description = getDescription(); + functionName_ = fmt::format( + "jit_hashaggr_v2_raw{}_partial{}_n{}_h{:016x}", + isRawInput_, + isPartialOutput_, + slots_.size(), + bits::hashBytes(1, description.data(), description.size())); +} + +std::string HashAggrJitChunk::getDescription() const { std::ostringstream out; - out << "jit_hashaggr_v2_" << (partialOutput_ ? 
"partial" : "final") << "_n" - << slots_.size(); for (const auto& slot : slots_) { - out << "_" << hashAggrJitKindName(slot.desc.kind) << "_" - << static_cast(slot.desc.kind) - << hashAggrJitValueKindName(slot.desc.rawInputKind) - << hashAggrJitValueKindName(slot.desc.accumulatorKind) << "o" - << slot.offset << "n" << slot.nullByte << "m" - << static_cast(slot.nullMask) - << (slot.desc.isCountStar() ? "s" : "x") - << (!slot.desc.isRawInput() ? "g" : "r") - << (slot.desc.isDecimal() ? "d" : "n") << "i" - << hashAggrJitRuntimeShapeName(slot.desc.inputShape()) << "o" - << hashAggrJitRuntimeShapeName(slot.desc.outputShape()); + out << slot.getDescription() << ";"; } - functionName_ = out.str(); - initFunctionName_ = functionName_ + "_init"; - addDenseFunctionName_ = functionName_ + "_add_dense"; - addDenseNoNullFunctionName_ = functionName_ + "_add_dense_no_null"; - extractFunctionName_ = functionName_ + "_extract"; + return out.str(); } std::string hashAggrJitKindName(HashAggrJitKind kind) { @@ -1179,25 +1198,6 @@ bool isHashAggrJitSupportedType(TypeKind kind) { } } -/* -std::string HashAggrJitDescriptor::signature() const { - return fmt::format( - "{}_{}_{}_{}_{}_{}_{}_{}", - hashAggrJitKindName(kind), - static_cast(kind), - hashAggrJitValueKindName(inputKind), - hashAggrJitValueKindName(accumulatorKind), - !isRawInput(), - isDecimal(), - hashAggrJitRuntimeShapeName(inputShape()), - hashAggrJitRuntimeShapeName(outputShape())); -} -*/ - -bool HashAggrJitChunk::canExtract() const { - return extract_ != nullptr; -} - bool HashAggrJitChunk::codegen() { if (addDense_) { return true; @@ -1207,16 +1207,16 @@ bool HashAggrJitChunk::codegen() { return false; } const auto& moduleKey = functionName_; - const auto& initFn = initFunctionName_; - const auto& addFn = addDenseFunctionName_; - const auto& addNoNullFn = addDenseNoNullFunctionName_; - const auto& extractFn = extractFunctionName_; + const auto initFn = functionName_ + "_init"; + const auto addFn = functionName_ + 
"_add_dense"; + const auto addNoNullFn = functionName_ + "_add_dense_no_null"; + const auto extractFn = functionName_ + "_extract"; module_ = jit->CompileModule( [&](llvm::Module& module) { const bool ok = genInitIR(module, initFn, slots_) && genAddDenseIR(module, addFn, slots_, true) && genAddDenseIR(module, addNoNullFn, slots_, false) && - genExtractIR(module, extractFn, slots_, partialOutput_); + genExtractIR(module, extractFn, slots_, isPartialOutput_); const bool hasError = !ok; logHashAggrJitFunctionIR(module, moduleKey, initFn, "init", hasError); logHashAggrJitFunctionIR(module, moduleKey, addFn, "add_dense", hasError); diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index a4aa114c8..e528ebbf2 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -278,7 +278,8 @@ class HashAggrJitChunk { public: explicit HashAggrJitChunk( std::vector slots, - bool partialOutput = false); + bool isRawInput, + bool partialOutput); bool codegen(); @@ -286,8 +287,6 @@ class HashAggrJitChunk { return addDense_ != nullptr; } - bool canExtract() const; - void init(char** newGroups, int32_t numNewGroups) const { init_(newGroups, numNewGroups); } @@ -312,30 +311,17 @@ class HashAggrJitChunk { return slots_; } + std::string getDescription() const; + const std::string& functionName() const { return functionName_; } - const std::string& initFunctionName() const { - return initFunctionName_; - } - const std::string& addDenseFunctionName() const { - return addDenseFunctionName_; - } - const std::string& addDenseNoNullFunctionName() const { - return addDenseNoNullFunctionName_; - } - const std::string& extractFunctionName() const { - return extractFunctionName_; - } private: std::vector slots_; - bool partialOutput_{false}; + bool isRawInput_{false}; + bool isPartialOutput_{false}; std::string functionName_; - std::string initFunctionName_; - std::string addDenseFunctionName_; - std::string 
addDenseNoNullFunctionName_; - std::string extractFunctionName_; CompiledModuleSP module_; HashAggrJitInitFunc init_{nullptr}; HashAggrJitAddDenseFunc addDense_{nullptr}; diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index 21b36e5cb..370521d25 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -163,9 +163,10 @@ struct HashAggrJitSlot { int32_t offset; int32_t nullByte; uint8_t nullMask; - // All aggregate-level traits live in the descriptor; IR-side code reads them - // through 'desc'. Only the row-layout fields above are slot-specific. + HashAggrJitDescriptor desc; + + std::string getDescription() const; }; bool isHashAggrJitSupportedType(TypeKind kind); From 5ad55167a4e01ada18386cd4550be176cbc3dd9c Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 14 Jun 2026 21:30:29 +0800 Subject: [PATCH 77/98] part4: localize HashAggrJit scratch buffers Keep only planned JIT chunks on GroupingSet and allocate transient add/extract runtimes locally so per-call state does not persist on the operator. --- bolt/exec/GroupingSet.cpp | 85 ++++++++++++++++++++++----------------- bolt/exec/GroupingSet.h | 17 -------- 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 28d5a9db7..3b8d4e641 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1096,6 +1096,16 @@ void GroupingSet::runHashAggrJitAddChunks( } jitExecuted.assign(aggregates_.size(), 0); + std::vector decoded; + std::vector inputRuntimes; + std::vector> rowChildren; + std::vector> + rowChildPtrs; + // Keeps input vectors alive for the DecodedVector buffers referenced by JIT + // during addDense. 
+ std::vector inputVectors; + std::vector inputRuntimePtrs; + std::vector newGroupPtrs; for (auto& chunk : hashAggrJitChunks_) { if (!chunk.isCodegenReady()) { VLOG(1) << "HashAggrJit chunk is not codegen-ready, skip add: " @@ -1104,12 +1114,12 @@ void GroupingSet::runHashAggrJitAddChunks( } const auto numSlots = chunk.slots().size(); - hashAggrJitDecoded_.resize(numSlots); - hashAggrJitInputRuntimes_.resize(numSlots); - hashAggrJitRowChildren_.resize(numSlots); - hashAggrJitRowChildPtrs_.resize(numSlots); - hashAggrJitInputVectors_.assign(numSlots, nullptr); - hashAggrJitInputRuntimePtrs_.assign(numSlots, nullptr); + decoded.resize(numSlots); + inputRuntimes.resize(numSlots); + rowChildren.resize(numSlots); + rowChildPtrs.resize(numSlots); + inputVectors.assign(numSlots, nullptr); + inputRuntimePtrs.assign(numSlots, nullptr); bool canRunChunk = true; std::string skipReason; @@ -1149,16 +1159,16 @@ void GroupingSet::runHashAggrJitAddChunks( skipReason = "lazy input with pushdown enabled"; break; } - hashAggrJitInputVectors_[slotIndex] = arg; - hashAggrJitDecoded_[slotIndex].decode(*arg, activeRows_); + inputVectors[slotIndex] = arg; + decoded[slotIndex].decode(*arg, activeRows_); const bool usesRowInputRuntime = slot.desc.inputShape() == jit::HashAggrJitRuntimeShape::Row; if (usesRowInputRuntime) { if (!fillHashAggrJitRowInputRuntime( - hashAggrJitInputRuntimes_[slotIndex], - hashAggrJitRowChildren_[slotIndex], - hashAggrJitRowChildPtrs_[slotIndex], - hashAggrJitDecoded_[slotIndex], + inputRuntimes[slotIndex], + rowChildren[slotIndex], + rowChildPtrs[slotIndex], + decoded[slotIndex], activeRows_, slot)) { canRunChunk = false; @@ -1166,16 +1176,14 @@ void GroupingSet::runHashAggrJitAddChunks( break; } } else { - hashAggrJitInputRuntimes_[slotIndex].scalar = - jit::HashAggrJitScalarInputRuntime{ - .values = hashAggrJitDecoded_[slotIndex].dataAsVoid(), - .indices = hashAggrJitDecoded_[slotIndex].indices(), - .nulls = 
hashAggrJitDecoded_[slotIndex].nulls(&activeRows_)}; + inputRuntimes[slotIndex].scalar = jit::HashAggrJitScalarInputRuntime{ + .values = decoded[slotIndex].dataAsVoid(), + .indices = decoded[slotIndex].indices(), + .nulls = decoded[slotIndex].nulls(&activeRows_)}; } - inputsMayHaveNulls = - inputsMayHaveNulls || hashAggrJitDecoded_[slotIndex].mayHaveNulls(); - hashAggrJitInputRuntimePtrs_[slotIndex] = - reinterpret_cast(&hashAggrJitInputRuntimes_[slotIndex]); + inputsMayHaveNulls = inputsMayHaveNulls || decoded[slotIndex].mayHaveNulls(); + inputRuntimePtrs[slotIndex] = + reinterpret_cast(&inputRuntimes[slotIndex]); } if (!canRunChunk) { @@ -1185,17 +1193,17 @@ void GroupingSet::runHashAggrJitAddChunks( } if (!newGroups.empty()) { - hashAggrJitNewGroups_.resize(newGroups.size()); + newGroupPtrs.resize(newGroups.size()); for (auto i = 0; i < newGroups.size(); ++i) { - hashAggrJitNewGroups_[i] = groups[newGroups[i]]; + newGroupPtrs[i] = groups[newGroups[i]]; } - chunk.init(hashAggrJitNewGroups_.data(), newGroups.size()); + chunk.init(newGroupPtrs.data(), newGroups.size()); } chunk.addDense( groups, activeRows_.end(), - hashAggrJitInputRuntimePtrs_.data(), + inputRuntimePtrs.data(), inputsMayHaveNulls); for (const auto& slot : chunk.slots()) { jitExecuted[slot.aggregateIndex] = 1; @@ -1217,12 +1225,18 @@ void GroupingSet::runHashAggrJitExtractChunks( } jitExtracted.assign(aggregates_.size(), 0); + std::vector outputRuntimes; + std::vector> + rowOutputChildren; + std::vector> + rowOutputChildPtrs; + std::vector resultPtrs; for (auto& chunk : hashAggrJitChunks_) { const auto numSlots = chunk.slots().size(); - hashAggrJitOutputRuntimes_.assign(numSlots, jit::HashAggrJitOutputRuntime{}); - hashAggrJitRowOutputChildren_.resize(numSlots); - hashAggrJitRowOutputChildPtrs_.resize(numSlots); - hashAggrJitResultPtrs_.assign(numSlots, nullptr); + outputRuntimes.assign(numSlots, jit::HashAggrJitOutputRuntime{}); + rowOutputChildren.resize(numSlots); + 
rowOutputChildPtrs.resize(numSlots); + resultPtrs.assign(numSlots, nullptr); bool canRunChunk = true; std::string skipReason; for (auto slotIndex = 0; slotIndex < numSlots; ++slotIndex) { @@ -1262,7 +1276,7 @@ void GroupingSet::runHashAggrJitExtractChunks( skipReason = "unsupported scalar output value kind"; break; } - hashAggrJitOutputRuntimes_[slotIndex].scalar = + outputRuntimes[slotIndex].scalar = jit::HashAggrJitScalarOutputRuntime{ .values = hashAggrJitRawOutputData( aggregateVector.get(), *outputKind), @@ -1272,24 +1286,23 @@ void GroupingSet::runHashAggrJitExtractChunks( aggregateVector->encoding() == VectorEncoding::Simple::ROW && slot.desc.outputShape() == jit::HashAggrJitRuntimeShape::Row) { if (!fillHashAggrJitRowOutputRuntime( - hashAggrJitOutputRuntimes_[slotIndex], - hashAggrJitRowOutputChildren_[slotIndex], - hashAggrJitRowOutputChildPtrs_[slotIndex], + outputRuntimes[slotIndex], + rowOutputChildren[slotIndex], + rowOutputChildPtrs[slotIndex], aggregateVector.get())) { canRunChunk = false; skipReason = "ROW output runtime requires flat scalar row children"; break; } } - hashAggrJitResultPtrs_[slotIndex] = - reinterpret_cast(&hashAggrJitOutputRuntimes_[slotIndex]); + resultPtrs[slotIndex] = reinterpret_cast(&outputRuntimes[slotIndex]); } if (!canRunChunk) { VLOG(1) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " << chunk.getDescription() << " reason=" << skipReason; continue; } - chunk.extract(groups.data(), groups.size(), hashAggrJitResultPtrs_.data()); + chunk.extract(groups.data(), groups.size(), resultPtrs.data()); for (const auto& slot : chunk.slots()) { jitExtracted[slot.aggregateIndex] = 1; } diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index bdd150b79..ed89d7ee6 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -463,23 +463,6 @@ class GroupingSet { #ifdef ENABLE_BOLT_JIT std::vector hashAggrJitChunks_; - std::vector hashAggrJitDecoded_; - std::vector hashAggrJitInputRuntimes_; - 
std::vector> - hashAggrJitRowChildren_; - std::vector> - hashAggrJitRowChildPtrs_; - // Keeps input vectors alive for the DecodedVector buffers referenced by - // JIT during addDense. - std::vector hashAggrJitInputVectors_; - std::vector hashAggrJitInputRuntimePtrs_; - std::vector hashAggrJitNewGroups_; - std::vector hashAggrJitOutputRuntimes_; - std::vector> - hashAggrJitRowOutputChildren_; - std::vector> - hashAggrJitRowOutputChildPtrs_; - std::vector hashAggrJitResultPtrs_; #endif // True if any aggregate accumulator allocates memory outside RowContainer's From d1991975644a5d3cc2e8fcf4aeea724138f7ba6e Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 11:09:28 +0800 Subject: [PATCH 78/98] add more benchmark cases --- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 62 ++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index d1e5389ca..5805db3de 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -28,6 +28,8 @@ namespace { struct HashAggrJitBenchmarkCase { std::shared_ptr plan; + int32_t minFuseWidth{4}; + int32_t maxFuseWidth{16}; }; enum class AggregationPlanKind { @@ -161,6 +163,27 @@ class HashAggrJitBenchmark : public VectorTestBase { name + "_merge_count", rows, counts, AggregationPlanKind::PartialFinal); } + void addFixedAggregateCountFuseWidthBenchmark( + const std::string& name, + int32_t aggregateCount) { + auto rows = makeRows(aggregateCount); + std::vector sums; + sums.reserve(aggregateCount); + for (auto i = 0; i < aggregateCount; ++i) { + sums.push_back(fmt::format("spark_sum(c{})", i + 1)); + } + + for (const auto fuseWidth : {4, 8, 16, 32}) { + addCase( + fmt::format("{}_aggs{}_fuse_width{}", name, aggregateCount, fuseWidth), + rows, + sums, + AggregationPlanKind::PartialFinal, + fuseWidth, + fuseWidth); + } + } + private: std::vector 
makeRows(int32_t width) { std::vector names; @@ -282,11 +305,19 @@ class HashAggrJitBenchmark : public VectorTestBase { return builder.planNode(); } - void run(const std::shared_ptr& plan, bool enableJit) { + void run( + const std::shared_ptr& plan, + bool enableJit, + int32_t minFuseWidth = 4, + int32_t maxFuseWidth = 16) { exec::test::AssertQueryBuilder(plan) .config(core::QueryConfig::kHashAggrJitEnabled, enableJit ? "true" : "false") - .config(core::QueryConfig::kHashAggrJitMinFuseWidth, "4") - .config(core::QueryConfig::kHashAggrJitMaxFuseWidth, "16") + .config( + core::QueryConfig::kHashAggrJitMinFuseWidth, + std::to_string(minFuseWidth)) + .config( + core::QueryConfig::kHashAggrJitMaxFuseWidth, + std::to_string(maxFuseWidth)) .copyResults(pool_.get()); } @@ -294,20 +325,32 @@ class HashAggrJitBenchmark : public VectorTestBase { const std::string& name, const std::vector& rows, const std::vector& aggregates, - AggregationPlanKind planKind = AggregationPlanKind::Single) { + AggregationPlanKind planKind = AggregationPlanKind::Single, + int32_t minFuseWidth = 4, + int32_t maxFuseWidth = 16) { auto testCase = std::make_unique(); testCase->plan = makePlan(rows, aggregates, planKind); + testCase->minFuseWidth = minFuseWidth; + testCase->maxFuseWidth = maxFuseWidth; // Warm up both paths so the benchmark compares steady-state execution and // doesn't charge one-time plan setup / JIT compilation to the first sample. 
- run(testCase->plan, false); - run(testCase->plan, true); + run(testCase->plan, false, minFuseWidth, maxFuseWidth); + run(testCase->plan, true, minFuseWidth, maxFuseWidth); auto* testCasePtr = testCase.get(); folly::addBenchmark(__FILE__, name + "_nojit", [this, testCasePtr]() { - run(testCasePtr->plan, false); + run( + testCasePtr->plan, + false, + testCasePtr->minFuseWidth, + testCasePtr->maxFuseWidth); return 1; }); folly::addBenchmark(__FILE__, name + "_jit", [this, testCasePtr]() { - run(testCasePtr->plan, true); + run( + testCasePtr->plan, + true, + testCasePtr->minFuseWidth, + testCasePtr->maxFuseWidth); return 1; }); folly::addBenchmark(__FILE__, "-", []() { return 0; }); @@ -347,6 +390,9 @@ int main(int argc, char** argv) { benchmark.addFloatingPointMinMaxBenchmark("width16", 16); benchmark.addFloatingPointMinMaxBenchmark("width32", 32); + benchmark.addFixedAggregateCountFuseWidthBenchmark( + "fixed_aggregate_count", 256); + // benchmark.addHighCardinalityExtractBenchmark("width4_high_card", 4); // benchmark.addHighCardinalityExtractBenchmark("width8_high_card", 8); // benchmark.addHighCardinalityExtractBenchmark("width16_high_card", 16); From 57bcf43059a0da06e7cf4b5fa4dafe88903eb7b3 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 15:36:35 +0800 Subject: [PATCH 79/98] fix failed uts --- bolt/exec/AggregateCompanionAdapter.cpp | 47 ++--------------- bolt/exec/AggregateCompanionAdapter.h | 50 +++++++++++++++---- bolt/exec/GroupingSet.cpp | 20 +++----- .../prestosql/aggregates/CountAggregate.cpp | 11 ++-- .../prestosql/aggregates/MinMaxAggregates.cpp | 7 +-- .../sparksql/aggregates/AverageAggregate.cpp | 15 +++--- .../sparksql/aggregates/DecimalSumAggregate.h | 8 +-- .../sparksql/aggregates/SumAggregate.cpp | 7 +-- bolt/jit/aggregation/HashAggrJit.cpp | 30 +++++------ bolt/jit/aggregation/HashAggrJit.h | 7 +-- bolt/jit/aggregation/HashAggrJitTypes.h | 42 +++++++++++++--- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 6 +-- 
bolt/jit/aggregation/ops/DecimalSumOps.cpp | 4 +- 13 files changed, 132 insertions(+), 122 deletions(-) diff --git a/bolt/exec/AggregateCompanionAdapter.cpp b/bolt/exec/AggregateCompanionAdapter.cpp index f032c7c2b..5d185e127 100644 --- a/bolt/exec/AggregateCompanionAdapter.cpp +++ b/bolt/exec/AggregateCompanionAdapter.cpp @@ -38,19 +38,6 @@ namespace bytedance::bolt::exec { -namespace { - -#ifdef ENABLE_BOLT_JIT -jit::HashAggrJitPlanContext toUnderlyingMergeContext( - const jit::HashAggrJitPlanContext& context) { - auto adapted = context; - adapted.isRawInput = false; - return adapted; -} -#endif - -} // namespace - void AggregateCompanionFunctionBase::setOffsetsInternal( int32_t offset, int32_t nullByte, @@ -82,14 +69,14 @@ bool AggregateCompanionFunctionBase::supportsToIntermediate() const { #ifdef ENABLE_BOLT_JIT bool AggregateCompanionFunctionBase::supportsHashAggrJit( - const jit::HashAggrJitPlanContext& /*context*/) const { - return false; + const jit::HashAggrJitPlanContext& context) const { + return fn_->supportsHashAggrJit(rewriteHashAggrJitContext(context)); } std::optional AggregateCompanionFunctionBase::createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& /*context*/) const { - return std::nullopt; + const jit::HashAggrJitPlanContext& context) const { + return fn_->createHashAggrJitDescriptor(rewriteHashAggrJitContext(context)); } #endif @@ -194,19 +181,6 @@ void AggregateCompanionAdapter::PartialFunction::extractValues( fn_->extractAccumulators(groups, numGroups, result); } -#ifdef ENABLE_BOLT_JIT -bool AggregateCompanionAdapter::PartialFunction::supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const { - return fn_->supportsHashAggrJit(context); -} - -std::optional -AggregateCompanionAdapter::PartialFunction::createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const { - return fn_->createHashAggrJitDescriptor(context); -} -#endif - void AggregateCompanionAdapter::MergeFunction::addRawInput( char** 
groups, const SelectivityVector& rows, @@ -225,19 +199,6 @@ void AggregateCompanionAdapter::MergeFunction::addSingleGroupRawInput( fn_->addSingleGroupIntermediateResults(group, rows, args, mayPushdown); } -#ifdef ENABLE_BOLT_JIT -bool AggregateCompanionAdapter::MergeFunction::supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const { - return fn_->supportsHashAggrJit(toUnderlyingMergeContext(context)); -} - -std::optional -AggregateCompanionAdapter::MergeFunction::createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const { - return fn_->createHashAggrJitDescriptor(toUnderlyingMergeContext(context)); -} -#endif - void AggregateCompanionAdapter::MergeFunction::toIntermediate( const SelectivityVector& rows, std::vector& args, diff --git a/bolt/exec/AggregateCompanionAdapter.h b/bolt/exec/AggregateCompanionAdapter.h index e23b3c35d..6337f55b4 100644 --- a/bolt/exec/AggregateCompanionAdapter.h +++ b/bolt/exec/AggregateCompanionAdapter.h @@ -54,6 +54,15 @@ class AggregateCompanionFunctionBase : public Aggregate { bool supportsToIntermediate() const override final; #ifdef ENABLE_BOLT_JIT + // Companion functions expose a different aggregation stage than the operator + // step implies (e.g. a merge companion runs at the kSingle step but consumes + // intermediate state and emits intermediate state). Each companion rewrites + // the planning context into the stage it actually represents by flipping the + // isRawInput/isPartialOutput flags; the absolute types in the context are + // stage-agnostic, so the derived input/output views follow automatically. 
+ virtual jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const = 0; + bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override; @@ -134,11 +143,15 @@ struct AggregateCompanionAdapter { : AggregateCompanionFunctionBase{std::move(fn), resultType} {} #ifdef ENABLE_BOLT_JIT - bool supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const override; - - std::optional createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const override; + // Partial companion: reads raw input and emits intermediate state. + jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const override { + auto rewritten = context; + rewritten.isRawInput = true; + rewritten.isPartialOutput = true; + rewritten.resultType = fn_->resultType(); + return rewritten; + } #endif void extractValues(char** groups, int32_t numGroups, VectorPtr* result) @@ -171,11 +184,15 @@ struct AggregateCompanionAdapter { bool mayPushdown) override; #ifdef ENABLE_BOLT_JIT - bool supportsHashAggrJit( - const jit::HashAggrJitPlanContext& context) const override; - - std::optional createHashAggrJitDescriptor( - const jit::HashAggrJitPlanContext& context) const override; + // Merge companion: reads intermediate state and emits intermediate state. + jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const override { + auto rewritten = context; + rewritten.isRawInput = false; + rewritten.isPartialOutput = true; + rewritten.resultType = fn_->resultType(); + return rewritten; + } #endif void extractValues(char** groups, int32_t numGroups, VectorPtr* result) @@ -189,6 +206,19 @@ struct AggregateCompanionAdapter { const TypePtr& resultType) : MergeFunction{std::move(fn), resultType} {} +#ifdef ENABLE_BOLT_JIT + // Merge-extract companion: reads intermediate state and emits the final + // result. 
+ jit::HashAggrJitPlanContext rewriteHashAggrJitContext( + const jit::HashAggrJitPlanContext& context) const override { + auto rewritten = context; + rewritten.isRawInput = false; + rewritten.isPartialOutput = false; + rewritten.resultType = fn_->resultType(); + return rewritten; + } +#endif + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override; }; diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 3b8d4e641..07ed48761 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -241,19 +241,16 @@ std::optional makeHashAggrJitSlot( return std::nullopt; } - std::vector inputTypes; - if (isRawInput) { - inputTypes = aggregate.rawInputTypes; - } else { - inputTypes = {aggregate.intermediateType}; - } - + // Fill the stage-agnostic absolute types and let the context derive the + // stage-specific input/output view from the flags. Companion functions may + // later flip the flags (via rewriteHashAggrJitContext) without needing to + // re-pick types here. const jit::HashAggrJitPlanContext context{ .isRawInput = isRawInput, .isPartialOutput = isPartialOutput, - .inputTypes = std::move(inputTypes), - .outputType = isPartialOutput ? 
aggregate.intermediateType - : aggregate.function->resultType()}; + .rawInputTypes = aggregate.rawInputTypes, + .intermediateType = aggregate.intermediateType, + .resultType = aggregate.function->resultType()}; if (!aggregate.function->supportsHashAggrJit(context)) { return std::nullopt; } @@ -1020,8 +1017,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { currentChunkSlots.clear(); return; } - jit::HashAggrJitChunk chunk( - std::move(currentChunkSlots), isRawInput_, isPartial_); + jit::HashAggrJitChunk chunk(std::move(currentChunkSlots)); if (chunk.codegen()) { hashAggrJitChunks_.push_back(std::move(chunk)); VLOG(1) << "HashAggrJit formed chunk: " diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index fa2e8839c..59895976a 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -56,20 +56,21 @@ class CountAggregate : public SimpleNumericAggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { + const auto inputTypes = context.inputTypes(); if (context.isRawInput) { if (context.isCountStar()) { return true; } - if (context.inputTypes.size() != 1 || context.inputTypes[0] == nullptr) { + if (inputTypes.size() != 1 || inputTypes[0] == nullptr) { return false; } - const auto& inputType = context.inputTypes[0]; + const auto& inputType = inputTypes[0]; return !inputType->isRow() && (inputType->isDecimal() || jit::isHashAggrJitSupportedType(inputType->kind())); } - return context.inputTypes.size() == 1 && context.inputTypes[0] != nullptr && - context.inputTypes[0]->kind() == TypeKind::BIGINT; + return inputTypes.size() == 1 && inputTypes[0] != nullptr && + inputTypes[0]->kind() == TypeKind::BIGINT; } std::optional createHashAggrJitDescriptor( @@ -80,7 +81,7 @@ class CountAggregate : public SimpleNumericAggregate { auto inputKind = 
jit::HashAggrJitValueKind::Int64; if (!context.isCountStar()) { auto maybeInputKind = - jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); + jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); if (!maybeInputKind.has_value()) { return std::nullopt; } diff --git a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp index f3b1ca220..792cc82c1 100644 --- a/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -55,10 +55,11 @@ class MinMaxAggregate : public SimpleNumericAggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputTypes.size() != 1 || context.inputTypes[0] == nullptr) { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || inputTypes[0] == nullptr) { return false; } - const auto& inputType = context.inputTypes[0]; + const auto& inputType = inputTypes[0]; return !inputType->isRow() && (inputType->isDecimal() || jit::isHashAggrJitSupportedType(inputType->kind()) || @@ -70,7 +71,7 @@ class MinMaxAggregate : public SimpleNumericAggregate { if (!supportsHashAggrJit(context)) { return std::nullopt; } - auto inputKind = jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); if (!inputKind.has_value()) { return std::nullopt; } diff --git a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp index 12c06591e..c9d4ec9e0 100644 --- a/bolt/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/AverageAggregate.cpp @@ -46,10 +46,11 @@ class AverageAggregate #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { + const 
auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { return false; } - const auto& inputType = context.inputTypes[0]; + const auto& inputType = inputTypes[0]; if (context.isRawInput) { if (inputType->isDecimal()) { return false; @@ -77,7 +78,7 @@ class AverageAggregate .ops = jit::getAvgOps()}; } - auto inputKind = jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); if (!inputKind.has_value()) { return std::nullopt; } @@ -144,10 +145,11 @@ class DecimalAverageAggregate : public DecimalAggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { return false; } - const auto& inputType = context.inputTypes[0]; + const auto& inputType = inputTypes[0]; if (context.isRawInput) { return inputType->isDecimal(); } @@ -161,7 +163,8 @@ class DecimalAverageAggregate : public DecimalAggregate { if (!supportsHashAggrJit(context)) { return std::nullopt; } - const auto& inputType = context.inputTypes[0]; + const auto inputTypes = context.inputTypes(); + const auto& inputType = inputTypes[0]; const auto& valueType = context.isRawInput ? 
inputType : inputType->childAt(0); return jit::HashAggrJitDescriptor{ diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 730d738e6..5d565953c 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -63,10 +63,11 @@ class DecimalSumAggregate : public exec::Aggregate { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { return false; } - const auto& inputType = context.inputTypes[0]; + const auto& inputType = inputTypes[0]; if (context.isRawInput) { return inputType->isDecimal() && (inputType->isShortDecimal() || inputType->isLongDecimal()); @@ -81,7 +82,8 @@ class DecimalSumAggregate : public exec::Aggregate { if (!supportsHashAggrJit(context)) { return std::nullopt; } - const auto& inputType = context.inputTypes[0]; + const auto inputTypes = context.inputTypes(); + const auto& inputType = inputTypes[0]; const auto& valueType = context.isRawInput ? 
inputType : inputType->childAt(0); return jit::HashAggrJitDescriptor{ diff --git a/bolt/functions/sparksql/aggregates/SumAggregate.cpp b/bolt/functions/sparksql/aggregates/SumAggregate.cpp index a92deebc5..acef1c147 100644 --- a/bolt/functions/sparksql/aggregates/SumAggregate.cpp +++ b/bolt/functions/sparksql/aggregates/SumAggregate.cpp @@ -46,10 +46,11 @@ class SumAggregate : public SumAggregateBase { #ifdef ENABLE_BOLT_JIT bool supportsHashAggrJit( const jit::HashAggrJitPlanContext& context) const override { - if (context.inputTypes.size() != 1 || !context.inputTypes[0]) { + const auto inputTypes = context.inputTypes(); + if (inputTypes.size() != 1 || !inputTypes[0]) { return false; } - const auto& inputType = context.inputTypes[0]; + const auto& inputType = inputTypes[0]; if (inputType->isRow() || inputType->isDecimal()) { return false; } @@ -63,7 +64,7 @@ class SumAggregate : public SumAggregateBase { return std::nullopt; } - auto inputKind = jit::hashAggrJitValueKind(context.inputTypes[0]->kind()); + auto inputKind = jit::hashAggrJitValueKind(context.inputTypes()[0]->kind()); if (!inputKind.has_value()) { return std::nullopt; } diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 79f5993af..f3429c933 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1006,8 +1006,7 @@ bool genAddDenseIR( bool genExtractIR( llvm::Module& module, const std::string& fn, - const std::vector& slots, - bool partialOutput) { + const std::vector& slots) { auto& context = module.getContext(); llvm::IRBuilder<> builder(context); HashAggrJitCodegen codegen(module); @@ -1052,8 +1051,9 @@ bool genExtractIR( output = std::make_unique(codegen, outputRuntime); } - auto* extractFn = partialOutput ? slot.desc.ops->extractAccumulators - : slot.desc.ops->extractResults; + auto* extractFn = slot.desc.context.isPartialOutput + ? 
slot.desc.ops->extractAccumulators + : slot.desc.ops->extractResults; if (extractFn == nullptr) { return false; } @@ -1074,13 +1074,14 @@ bool genExtractIR( } // namespace std::string HashAggrJitSlot::getDescription() const { + const auto inputTypes = desc.context.inputTypes(); std::ostringstream inputs; inputs << "["; - for (size_t i = 0; i < desc.context.inputTypes.size(); ++i) { + for (size_t i = 0; i < inputTypes.size(); ++i) { if (i > 0) { inputs << ","; } - inputs << desc.context.inputTypes[i]->toString(); + inputs << inputTypes[i]->toString(); } inputs << "]"; @@ -1090,22 +1091,15 @@ std::string HashAggrJitSlot::getDescription() const { desc.context.isRawInput, desc.context.isPartialOutput, inputs.str(), - desc.context.outputType->toString(), + desc.context.outputType()->toString(), offset); } -HashAggrJitChunk::HashAggrJitChunk( - std::vector slots, - bool isRawInput, - bool isPartialOutput) - : slots_(std::move(slots)), - isRawInput_(isRawInput), - isPartialOutput_(isPartialOutput) { +HashAggrJitChunk::HashAggrJitChunk(std::vector slots) + : slots_(std::move(slots)) { const auto description = getDescription(); functionName_ = fmt::format( - "jit_hashaggr_v2_raw{}_partial{}_n{}_h{:016x}", - isRawInput_, - isPartialOutput_, + "jit_hashaggr_v2_n{}_h{:016x}", slots_.size(), bits::hashBytes(1, description.data(), description.size())); } @@ -1216,7 +1210,7 @@ bool HashAggrJitChunk::codegen() { const bool ok = genInitIR(module, initFn, slots_) && genAddDenseIR(module, addFn, slots_, true) && genAddDenseIR(module, addNoNullFn, slots_, false) && - genExtractIR(module, extractFn, slots_, isPartialOutput_); + genExtractIR(module, extractFn, slots_); const bool hasError = !ok; logHashAggrJitFunctionIR(module, moduleKey, initFn, "init", hasError); logHashAggrJitFunctionIR(module, moduleKey, addFn, "add_dense", hasError); diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index e528ebbf2..94d477040 100644 --- 
a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -276,10 +276,7 @@ using HashAggrJitExtractFunc = void (*)(char** groups, int32_t numGroups, char** class HashAggrJitChunk { public: - explicit HashAggrJitChunk( - std::vector slots, - bool isRawInput, - bool partialOutput); + explicit HashAggrJitChunk(std::vector slots); bool codegen(); @@ -319,8 +316,6 @@ class HashAggrJitChunk { private: std::vector slots_; - bool isRawInput_{false}; - bool isPartialOutput_{false}; std::string functionName_; CompiledModuleSP module_; HashAggrJitInitFunc init_{nullptr}; diff --git a/bolt/jit/aggregation/HashAggrJitTypes.h b/bolt/jit/aggregation/HashAggrJitTypes.h index 370521d25..982acbca2 100644 --- a/bolt/jit/aggregation/HashAggrJitTypes.h +++ b/bolt/jit/aggregation/HashAggrJitTypes.h @@ -79,14 +79,40 @@ union HashAggrJitOutputRuntime { HashAggrJitRowOutputRuntime row; }; +// Stage-agnostic, absolute description of an aggregate's types. The three type +// fields below always hold the same values regardless of which stage +// (raw/intermediate/partial/final) this context represents; the active stage is +// selected purely by the isRawInput/isPartialOutput flags. The inputTypes() and +// outputType() accessors derive the stage-specific view from those flags, so a +// flag flip (e.g. by a companion function) automatically yields a consistent +// input/output type without any separate type-rewrite step. struct HashAggrJitPlanContext { bool isRawInput{false}; bool isPartialOutput{false}; - std::vector inputTypes; - TypePtr outputType; + // Original (raw) input argument types. Empty for count(*). + std::vector rawInputTypes; + // The intermediate accumulator type (i.e. the partial output / merge input). + TypePtr intermediateType; + // The final aggregate result type. + TypePtr resultType; + + // Stage-derived input view: raw inputs when reading raw input, otherwise the + // single intermediate accumulator type. 
+ std::vector inputTypes() const { + if (isRawInput) { + return rawInputTypes; + } + return {intermediateType}; + } + + // Stage-derived output view: the intermediate accumulator type for partial + // output, otherwise the final result type. + TypePtr outputType() const { + return isPartialOutput ? intermediateType : resultType; + } bool isCountStar() const { - return isRawInput && inputTypes.empty(); + return isRawInput && rawInputTypes.empty(); } }; @@ -143,16 +169,16 @@ struct HashAggrJitDescriptor { } HashAggrJitRuntimeShape inputShape() const { - return context.inputTypes.size() == 1 && context.inputTypes[0] && - context.inputTypes[0]->isRow() + const auto inputTypes = context.inputTypes(); + return inputTypes.size() == 1 && inputTypes[0] && inputTypes[0]->isRow() ? HashAggrJitRuntimeShape::Row : HashAggrJitRuntimeShape::Scalar; } HashAggrJitRuntimeShape outputShape() const { - return context.outputType && context.outputType->isRow() - ? HashAggrJitRuntimeShape::Row - : HashAggrJitRuntimeShape::Scalar; + const auto outputType = context.outputType(); + return outputType && outputType->isRow() ? 
HashAggrJitRuntimeShape::Row + : HashAggrJitRuntimeShape::Scalar; } // std::string signature() const; diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index b7dfad7f4..79da8465f 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -89,7 +89,7 @@ void compileDecimalAvgAddIntermediateResults( auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "decimal_avg_merge", function, continueBlock); const auto [sumPrecision, _] = - hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes[0]); + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); const auto sumKind = hashAggrJitDecimalKindForPrecision(sumPrecision); auto* sumRow = input.readRowField(row, 0, sumKind); auto* countRow = input.readRowField(row, 1, HashAggrJitValueKind::Int64); @@ -133,9 +133,9 @@ void emitDecimalAvgExtract( bool partialOutput) { auto& b = codegen.builder(); const auto [inputPrecision, inputScale] = - hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes[0]); + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); const auto [outputPrecision, outputScale] = - hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType); + hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType()); const bool longDecimal = hashAggrJitDecimalKindForPrecision( outputPrecision) == HashAggrJitValueKind::Int128; const char* fn = partialOutput diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 2c7da8399..95ee5d387 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -87,7 +87,7 @@ void compileDecimalSumAddIntermediateResults( auto* mergeBlock = llvm::BasicBlock::Create( codegen.module().getContext(), "decimal_sum_merge", function, continueBlock); const auto [sumPrecision, _] = - 
hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes[0]); + hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); const auto sumKind = hashAggrJitDecimalKindForPrecision(sumPrecision); auto* sumRow = input.readRowField(row, 0, sumKind); auto* incomingIsEmpty = @@ -138,7 +138,7 @@ void emitDecimalSumExtract( // long/short decimal and overflow precision are decided by the actual // output decimal type of this aggregation stage. const auto [outPrecision, outScale] = - hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType); + hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType()); const bool longDecimal = hashAggrJitDecimalKindForPrecision(outPrecision) == HashAggrJitValueKind::Int128; const char* fn = partialOutput From ebf11e74777087cd4a2439f3bf4b6c44876ed5df Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 17:05:47 +0800 Subject: [PATCH 80/98] enable not emit frame pointer to use perf --- CMakeLists.txt | 10 ++++++++++ Makefile | 8 ++++++++ conanfile.py | 3 +++ 3 files changed, 21 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c72259f3..e097acf03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -458,6 +458,16 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsigned-char") endif() +# Keep frame pointers and avoid sibling-call optimization so perf can fully +# unwind stacks with the cheap frame-pointer call-graph (no DWARF needed). +# Off by default; enable with -DBOLT_ENABLE_FRAME_POINTER=ON when profiling. 
+option(BOLT_ENABLE_FRAME_POINTER + "Preserve frame pointers for perf stack unwinding" OFF) +if(BOLT_ENABLE_FRAME_POINTER) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -fno-optimize-sibling-calls") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -fno-optimize-sibling-calls") +endif() + # Under Ninja, we are able to designate certain targets large enough to require restricted # parallelism. if("${MAX_HIGH_MEM_JOBS}") diff --git a/Makefile b/Makefile index 4b2d1a8d5..eec5a4ddd 100644 --- a/Makefile +++ b/Makefile @@ -83,6 +83,8 @@ BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS ?= "OFF" # Control whether to build tests with coverage instrumentation BOLT_BUILD_TESTING_WITH_COVERAGE ?= "OFF" +# Control whether to keep frame pointers (for perf stack unwinding) +BOLT_ENABLE_FRAME_POINTER ?= "OFF" # ----------------------------------------------------------------- # TODO: remove `BUILD_USER` and `BUILD_CHANNEL` @@ -204,6 +206,7 @@ conan_build: conan_install NUM_THREADS=$(NUM_THREADS) \ BOLT_BUILD_BENCHMARKS=${BOLT_BUILD_BENCHMARKS} \ BOLT_BUILD_TESTING_WITH_COVERAGE=${BOLT_BUILD_TESTING_WITH_COVERAGE} \ + BOLT_ENABLE_FRAME_POINTER=${BOLT_ENABLE_FRAME_POINTER} \ conan build ../.. --name=bolt --version=${BUILD_VERSION} --user=${BUILD_USER} --channel=${BUILD_CHANNEL} \ -s llvm-core/*:build_type=Release \ -s "&:build_type=${BUILD_TYPE}" \ @@ -306,6 +309,11 @@ benchmarks-build-spark: benchmarks-build-relwithdebinfo: $(MAKE) conan_build BUILD_TYPE=RelWithDebInfo BOLT_BUILD_BENCHMARKS="ON" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=False -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" +# Same as benchmarks-build-spark but keeps frame pointers so perf can unwind +# stacks with the cheap frame-pointer call-graph (no DWARF needed). 
+benchmarks-build-spark-profile: + $(MAKE) conan_build BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS="ON" BOLT_ENABLE_FRAME_POINTER="ON" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=True -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" + unittest_debug: unittest unittest: debug_with_test ctest --test-dir $(BUILD_BASE_DIR)/Debug --timeout 7200 -j $(NUM_THREADS) --output-on-failure diff --git a/conanfile.py b/conanfile.py index d28518e9b..47b38ed5b 100644 --- a/conanfile.py +++ b/conanfile.py @@ -569,6 +569,9 @@ def generate(self): elif os.getenv("BOLT_BUILD_BENCHMARKS_BASIC", "OFF") == "ON": tc.cache_variables["BOLT_BUILD_BENCHMARKS_BASIC"] = "ON" + if os.getenv("BOLT_ENABLE_FRAME_POINTER", "OFF") == "ON": + tc.cache_variables["BOLT_ENABLE_FRAME_POINTER"] = "ON" + tc.generate() # generate conantoolchain.cmake & xxx-config.cmake From a6ff4294d430f7238847ab7b81908bb0c0ea1b2b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 19:07:58 +0800 Subject: [PATCH 81/98] enabled profiling jit with intel vtune --- Makefile | 11 ++++ bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 4 +- bolt/jit/CMakeLists.txt | 24 ++++++++ bolt/jit/ThrustJITv2.cpp | 61 +++++++++++++++++++ conanfile.py | 6 ++ 5 files changed, 104 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index eec5a4ddd..3817355d2 100644 --- a/Makefile +++ b/Makefile @@ -85,6 +85,9 @@ BOLT_BUILD_BENCHMARKS ?= "OFF" BOLT_BUILD_TESTING_WITH_COVERAGE ?= "OFF" # Control whether to keep frame pointers (for perf stack unwinding) BOLT_ENABLE_FRAME_POINTER ?= "OFF" +# Control whether to report JIT symbols to Intel VTune (jitprofiling) +BOLT_ENABLE_VTUNE_JIT ?= "OFF" +VTUNE_SDK_DIR ?= # ----------------------------------------------------------------- # TODO: remove `BUILD_USER` and `BUILD_CHANNEL` @@ -207,6 +210,8 @@ conan_build: conan_install BOLT_BUILD_BENCHMARKS=${BOLT_BUILD_BENCHMARKS} \ 
BOLT_BUILD_TESTING_WITH_COVERAGE=${BOLT_BUILD_TESTING_WITH_COVERAGE} \ BOLT_ENABLE_FRAME_POINTER=${BOLT_ENABLE_FRAME_POINTER} \ + BOLT_ENABLE_VTUNE_JIT=${BOLT_ENABLE_VTUNE_JIT} \ + VTUNE_SDK_DIR=${VTUNE_SDK_DIR} \ conan build ../.. --name=bolt --version=${BUILD_VERSION} --user=${BUILD_USER} --channel=${BUILD_CHANNEL} \ -s llvm-core/*:build_type=Release \ -s "&:build_type=${BUILD_TYPE}" \ @@ -314,6 +319,12 @@ benchmarks-build-relwithdebinfo: benchmarks-build-spark-profile: $(MAKE) conan_build BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS="ON" BOLT_ENABLE_FRAME_POINTER="ON" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=True -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" +# Same as benchmarks-build-spark-profile but also reports JIT symbols to Intel +# VTune (libjitprofiling). Override the SDK path with VTUNE_SDK_DIR=... if it is +# not under the default /opt/intel/oneapi/vtune/2023.2.0/sdk. +benchmarks-build-spark-vtune: + $(MAKE) conan_build BUILD_TYPE=Release BOLT_BUILD_BENCHMARKS="ON" BOLT_ENABLE_FRAME_POINTER="ON" BOLT_ENABLE_VTUNE_JIT="ON" VTUNE_SDK_DIR="${VTUNE_SDK_DIR}" CONAN_CONFIG=" -c bolt/*:tools.build:skip_test=False" CONAN_OPTIONS="-o bolt/*:spark_compatible=True -o bolt/*:enable_testutil=True -o bolt/*:enable_perf=True" + unittest_debug: unittest unittest: debug_with_test ctest --test-dir $(BUILD_BASE_DIR)/Debug --timeout 7200 -j $(NUM_THREADS) --output-on-failure diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 5805db3de..be49bd4eb 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -334,8 +334,8 @@ class HashAggrJitBenchmark : public VectorTestBase { testCase->maxFuseWidth = maxFuseWidth; // Warm up both paths so the benchmark compares steady-state execution and // doesn't charge one-time plan setup / JIT compilation to the first sample. 
- run(testCase->plan, false, minFuseWidth, maxFuseWidth); - run(testCase->plan, true, minFuseWidth, maxFuseWidth); + // run(testCase->plan, false, minFuseWidth, maxFuseWidth); + // run(testCase->plan, true, minFuseWidth, maxFuseWidth); auto* testCasePtr = testCase.get(); folly::addBenchmark(__FILE__, name + "_nojit", [this, testCasePtr]() { run( diff --git a/bolt/jit/CMakeLists.txt b/bolt/jit/CMakeLists.txt index 4962548ed..1f4473fb3 100644 --- a/bolt/jit/CMakeLists.txt +++ b/bolt/jit/CMakeLists.txt @@ -29,6 +29,30 @@ bolt_add_library( target_link_libraries(bolt_thrustjit PUBLIC llvm-core::llvm-core date::date fmt::fmt Folly::folly) +# Optional: report JIT-generated symbols to Intel VTune via the JIT Profiling +# API (libjitprofiling). Enable with -DBOLT_ENABLE_VTUNE_JIT=ON and point +# VTUNE_SDK_DIR at the VTune SDK (e.g. /opt/intel/oneapi/vtune//sdk). +# Runtime reporting is further gated by the BOLT_JIT_VTUNE env var. +if(BOLT_ENABLE_VTUNE_JIT) + if(NOT VTUNE_SDK_DIR) + set(VTUNE_SDK_DIR "/opt/intel/oneapi/vtune/2023.2.0/sdk") + endif() + find_path(VTUNE_JIT_INCLUDE_DIR jitprofiling.h + HINTS "${VTUNE_SDK_DIR}/include") + find_library(VTUNE_JIT_LIBRARY jitprofiling + HINTS "${VTUNE_SDK_DIR}/lib64" "${VTUNE_SDK_DIR}/lib32") + if(NOT VTUNE_JIT_INCLUDE_DIR OR NOT VTUNE_JIT_LIBRARY) + message(FATAL_ERROR + "BOLT_ENABLE_VTUNE_JIT=ON but jitprofiling.h/libjitprofiling not found " + "under VTUNE_SDK_DIR=${VTUNE_SDK_DIR}") + endif() + target_compile_definitions(bolt_thrustjit PRIVATE BOLT_ENABLE_VTUNE_JIT) + target_include_directories(bolt_thrustjit PRIVATE "${VTUNE_JIT_INCLUDE_DIR}") + # libjitprofiling needs dl/pthread at link time. 
+ target_link_libraries(bolt_thrustjit PRIVATE "${VTUNE_JIT_LIBRARY}" ${CMAKE_DL_LIBS}) + message(STATUS "bolt_thrustjit: VTune JIT profiling enabled (${VTUNE_JIT_LIBRARY})") +endif() + target_compile_options( bolt_thrustjit PRIVATE $<$:-Werror=return-type> ) diff --git a/bolt/jit/ThrustJITv2.cpp b/bolt/jit/ThrustJITv2.cpp index 1adae7e6f..03e43e019 100644 --- a/bolt/jit/ThrustJITv2.cpp +++ b/bolt/jit/ThrustJITv2.cpp @@ -33,8 +33,13 @@ #include #include #include +#include #include +#ifdef BOLT_ENABLE_VTUNE_JIT +#include +#endif + namespace bytedance::bolt::jit { namespace { @@ -94,6 +99,57 @@ void appendPerfMap( std::fclose(file); } +#ifdef BOLT_ENABLE_VTUNE_JIT +// Returns whether VTune JIT symbol reporting is enabled. Unlike perf's +// /tmp/perf-.map (which VTune does not read), VTune resolves JIT code only +// when the process actively reports each function via the Intel JIT Profiling +// API (iJIT_NotifyEvent). Controlled by the BOLT_JIT_VTUNE env var to avoid the +// reporting overhead in normal runs. +bool vtuneJitEnabled() { + static const bool enabled = std::getenv("BOLT_JIT_VTUNE") != nullptr; + return enabled; +} + +// Reports symbols of a freshly loaded JIT object to VTune through the Intel JIT +// Profiling API so that VTune can attribute samples in JIT-generated machine +// code to function names instead of "outside any known module". +void notifyVTune( + const llvm::object::ObjectFile& obj, + const llvm::RuntimeDyld::LoadedObjectInfo& loadedInfo) { + // Use the relocated/loaded object so symbol addresses are the runtime ones. 
+ auto debugObjOwner = loadedInfo.getObjectForDebug(obj); + const llvm::object::ObjectFile& debugObj = *debugObjOwner.getBinary(); + + static std::mutex vtuneMutex; + std::lock_guard guard(vtuneMutex); + + for (const auto& [symbol, size] : + llvm::object::computeSymbolSizes(debugObj)) { + auto typeOr = symbol.getType(); + if (!typeOr || *typeOr != llvm::object::SymbolRef::ST_Function) { + continue; + } + auto addrOr = symbol.getAddress(); + auto nameOr = symbol.getName(); + if (!addrOr || !nameOr || size == 0) { + llvm::consumeError(addrOr.takeError()); + llvm::consumeError(nameOr.takeError()); + continue; + } + // iJIT_Method_Load.method_name is a non-const char*; keep a stable owning + // copy for the duration of the iJIT_NotifyEvent call. + std::string name = nameOr->str(); + iJIT_Method_Load jit_method = {}; + jit_method.method_id = iJIT_GetNewMethodID(); + jit_method.method_name = name.data(); + jit_method.method_load_address = reinterpret_cast(*addrOr); + jit_method.method_size = static_cast(size); + iJIT_NotifyEvent( + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, static_cast(&jit_method)); + } +} +#endif // BOLT_ENABLE_VTUNE_JIT + } // namespace llvm::Expected> ThrustJITv2::Create() { @@ -129,6 +185,11 @@ llvm::Expected> ThrustJITv2::Create() { if (perfMapEnabled()) { appendPerfMap(obj, loadedInfo); } +#ifdef BOLT_ENABLE_VTUNE_JIT + if (vtuneJitEnabled()) { + notifyVTune(obj, loadedInfo); + } +#endif llvm::orc::ResourceKey resourceKey = 0; if (auto err = mr.withResourceKeyDo( [&](llvm::orc::ResourceKey key) { diff --git a/conanfile.py b/conanfile.py index 47b38ed5b..07264d738 100644 --- a/conanfile.py +++ b/conanfile.py @@ -572,6 +572,12 @@ def generate(self): if os.getenv("BOLT_ENABLE_FRAME_POINTER", "OFF") == "ON": tc.cache_variables["BOLT_ENABLE_FRAME_POINTER"] = "ON" + if os.getenv("BOLT_ENABLE_VTUNE_JIT", "OFF") == "ON": + tc.cache_variables["BOLT_ENABLE_VTUNE_JIT"] = "ON" + vtune_sdk_dir = os.getenv("VTUNE_SDK_DIR") + if vtune_sdk_dir: + 
tc.cache_variables["VTUNE_SDK_DIR"] = vtune_sdk_dir + tc.generate() # generate conantoolchain.cmake & xxx-config.cmake From 64f7bebfbe394db44a557063f39d6aa5a1c3af3b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 21:09:24 +0800 Subject: [PATCH 82/98] fix diff in decimal sum --- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 4 +- bolt/jit/aggregation/HashAggrJit.cpp | 13 ++++--- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 8 +++- .../runtime/HashAggrDecimalRuntime.cpp | 38 +++++++++++++------ 4 files changed, 43 insertions(+), 20 deletions(-) diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index be49bd4eb..5805db3de 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -334,8 +334,8 @@ class HashAggrJitBenchmark : public VectorTestBase { testCase->maxFuseWidth = maxFuseWidth; // Warm up both paths so the benchmark compares steady-state execution and // doesn't charge one-time plan setup / JIT compilation to the first sample. 
- // run(testCase->plan, false, minFuseWidth, maxFuseWidth); - // run(testCase->plan, true, minFuseWidth, maxFuseWidth); + run(testCase->plan, false, minFuseWidth, maxFuseWidth); + run(testCase->plan, true, minFuseWidth, maxFuseWidth); auto* testCasePtr = testCase.get(); folly::addBenchmark(__FILE__, name + "_nojit", [this, testCasePtr]() { run( diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index f3429c933..e0c49c5fc 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -94,7 +94,8 @@ void jit_HashAggrExtractFinalShortDecimalSum( char* group, int32_t offset, int32_t precision, - int32_t scale); + int32_t scale, + int32_t accumulatorIsNull); [[maybe_unused]] __attribute__((used)) const void* const kHashAggrRuntimeLinkAnchors[] = { @@ -123,27 +124,27 @@ void ensureBuiltinDeclarations(llvm::Module& module) { declareFunction(module, "jit_HashAggrResizeVector", voidTy, {i8PtrTy, i32Ty}); // Decimal extract helpers. - // Sum: (vector, row, group, offset, precision, scale). + // Sum: (vector, row, group, offset, precision, scale, accumulatorIsNull). declareFunction( module, "jit_HashAggrExtractFinalShortDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); declareFunction( module, "jit_HashAggrExtractFinalLongDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); declareFunction( module, "jit_HashAggrExtractPartialShortDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); declareFunction( module, "jit_HashAggrExtractPartialLongDecimalSum", voidTy, - {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty}); + {i8PtrTy, i32Ty, i8PtrTy, i32Ty, i32Ty, i32Ty, i32Ty}); // Avg: (vector, row, group, offset, precision, scale, resultPrecision, // resultScale). 
diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 95ee5d387..0e90423c0 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -146,6 +146,11 @@ void emitDecimalSumExtract( : "jit_HashAggrExtractPartialShortDecimalSum") : (longDecimal ? "jit_HashAggrExtractFinalLongDecimalSum" : "jit_HashAggrExtractFinalShortDecimalSum"); + // Mirror the non-JIT extract's leading `if (isNull(group))` check: a group + // whose accumulator null flag is set (e.g. an overflowed intermediate result + // merged in) must produce null, regardless of the sum/isEmpty fields. + auto* accumulatorIsNull = + b.CreateZExt(codegen.isAccumulatorNull(group, slot), b.getInt32Ty()); b.CreateCall( codegen.module().getFunction(fn), {vector, @@ -153,7 +158,8 @@ void emitDecimalSumExtract( group, b.getInt32(slot.offset), b.getInt32(outPrecision), - b.getInt32(outScale)}); + b.getInt32(outScale), + accumulatorIsNull}); } void compileDecimalSumExtractAccumulators( diff --git a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp index 208bb9a89..4718c1b55 100644 --- a/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp +++ b/bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp @@ -115,12 +115,15 @@ void jitHashAggrExtractFinalDecimalSum( int32_t row, char* group, int32_t offset, - int32_t precision) { + int32_t precision, + bool accumulatorIsNull) { auto* state = reinterpret_cast(group + offset); auto* flat = reinterpret_cast(vector) ->asUnchecked>(); - if (state->isEmpty) { + // A null accumulator (e.g. an overflowed intermediate result merged into the + // group) produces null, matching the non-JIT `if (isNull(group))` check. 
+ if (accumulatorIsNull || state->isEmpty) { flat->setNull(row, true); return; } @@ -140,7 +143,8 @@ void jitHashAggrExtractPartialDecimalSum( int32_t row, char* group, int32_t offset, - int32_t precision) { + int32_t precision, + bool accumulatorIsNull) { auto* state = reinterpret_cast(group + offset); auto* rowVector = reinterpret_cast(vector) @@ -150,6 +154,14 @@ void jitHashAggrExtractPartialDecimalSum( auto* isEmptyVector = rowVector->childAt(1)->asUnchecked>(); rowVector->setNull(row, false); + // A null accumulator (e.g. an overflowed intermediate result merged into the + // group) outputs sum=0, isEmpty=true, matching the non-JIT + // `if (isNull(group))` branch. + if (accumulatorIsNull) { + sumVector->set(row, 0); + isEmptyVector->set(row, true); + return; + } if (state->isEmpty) { sumVector->set(row, 0); isEmptyVector->set(row, true); @@ -233,9 +245,10 @@ jit_HashAggrExtractFinalShortDecimalSum( char* group, int32_t offset, int32_t precision, - int32_t /*scale*/) { + int32_t /*scale*/, + int32_t accumulatorIsNull) { jitHashAggrExtractFinalDecimalSum( - vector, row, group, offset, precision); + vector, row, group, offset, precision, accumulatorIsNull != 0); } __attribute__((__visibility__("default"))) void @@ -245,9 +258,10 @@ jit_HashAggrExtractFinalLongDecimalSum( char* group, int32_t offset, int32_t precision, - int32_t /*scale*/) { + int32_t /*scale*/, + int32_t accumulatorIsNull) { jitHashAggrExtractFinalDecimalSum( - vector, row, group, offset, precision); + vector, row, group, offset, precision, accumulatorIsNull != 0); } // Partial decimal sum extract: write row(sum:decimal, isEmpty:bool). 
@@ -258,9 +272,10 @@ jit_HashAggrExtractPartialShortDecimalSum( char* group, int32_t offset, int32_t precision, - int32_t /*scale*/) { + int32_t /*scale*/, + int32_t accumulatorIsNull) { jitHashAggrExtractPartialDecimalSum( - vector, row, group, offset, precision); + vector, row, group, offset, precision, accumulatorIsNull != 0); } __attribute__((__visibility__("default"))) void @@ -270,9 +285,10 @@ jit_HashAggrExtractPartialLongDecimalSum( char* group, int32_t offset, int32_t precision, - int32_t /*scale*/) { + int32_t /*scale*/, + int32_t accumulatorIsNull) { jitHashAggrExtractPartialDecimalSum( - vector, row, group, offset, precision); + vector, row, group, offset, precision, accumulatorIsNull != 0); } // Partial decimal avg extract: write row(sum:decimal, count:bigint). From 380fd84f13d4f39304840a2aae3a7591b502c519 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 21:43:57 +0800 Subject: [PATCH 83/98] fix diff in decimal avg --- bolt/jit/aggregation/ops/DecimalAvgOps.cpp | 24 ++++++++-------------- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 18 ++++------------ 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp index 79da8465f..86503ffec 100644 --- a/bolt/jit/aggregation/ops/DecimalAvgOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalAvgOps.cpp @@ -5,6 +5,7 @@ #ifdef ENABLE_BOLT_JIT +#include #include #include "bolt/jit/aggregation/HashAggrJit.h" @@ -129,11 +130,14 @@ void emitDecimalAvgExtract( llvm::Value* vector, llvm::Value* row, llvm::Value* group, - const HashAggrJitSlot& slot, - bool partialOutput) { + const HashAggrJitSlot& slot) { auto& b = codegen.builder(); - const auto [inputPrecision, inputScale] = + const bool partialOutput = slot.desc.context.isPartialOutput; + auto [inputPrecision, inputScale] = hashAggrJitDecimalPrecisionScale(slot.desc.context.inputTypes()[0]); + if (slot.desc.isRawInput()) { + inputPrecision = std::min(38, 
inputPrecision + 10); + } const auto [outputPrecision, outputScale] = hashAggrJitDecimalPrecisionScale(slot.desc.context.outputType()); const bool longDecimal = hashAggrJitDecimalKindForPrecision( @@ -161,12 +165,7 @@ void compileDecimalAvgExtractAccumulators( const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { emitDecimalAvgExtract( - codegen, - target.output.vector(), - target.row, - group, - slot, - /*partialOutput=*/true); + codegen, target.output.vector(), target.row, group, slot); } void compileDecimalAvgExtractValues( @@ -175,12 +174,7 @@ void compileDecimalAvgExtractValues( const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { emitDecimalAvgExtract( - codegen, - target.output.vector(), - target.row, - group, - slot, - /*partialOutput=*/false); + codegen, target.output.vector(), target.row, group, slot); } } // namespace diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 0e90423c0..2f947aac1 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp @@ -132,9 +132,9 @@ void emitDecimalSumExtract( llvm::Value* vector, llvm::Value* row, llvm::Value* group, - const HashAggrJitSlot& slot, - bool partialOutput) { + const HashAggrJitSlot& slot) { auto& b = codegen.builder(); + const bool partialOutput = slot.desc.context.isPartialOutput; // long/short decimal and overflow precision are decided by the actual // output decimal type of this aggregation stage. 
const auto [outPrecision, outScale] = @@ -168,12 +168,7 @@ void compileDecimalSumExtractAccumulators( const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { emitDecimalSumExtract( - codegen, - target.output.vector(), - target.row, - group, - slot, - /*partialOutput=*/true); + codegen, target.output.vector(), target.row, group, slot); } void compileDecimalSumExtractValues( @@ -182,12 +177,7 @@ void compileDecimalSumExtractValues( const HashAggrJitSlot& slot, const HashAggrJitExtractTarget& target) { emitDecimalSumExtract( - codegen, - target.output.vector(), - target.row, - group, - slot, - /*partialOutput=*/false); + codegen, target.output.vector(), target.row, group, slot); } } // namespace From 5bff5c57e593394138f4f4538680f5844cea642b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 22:28:56 +0800 Subject: [PATCH 84/98] add test config --- bolt/core/QueryConfig.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bolt/core/QueryConfig.h b/bolt/core/QueryConfig.h index 1d96c5ebf..08b646c1b 100644 --- a/bolt/core/QueryConfig.h +++ b/bolt/core/QueryConfig.h @@ -1613,11 +1613,11 @@ class QueryConfig { } bool enableHashAggrJit() const { - return get(kHashAggrJitEnabled, false); + return get(kHashAggrJitEnabled, true); } int32_t hashAggrJitMinFuseWidth() const { - return get(kHashAggrJitMinFuseWidth, 4); + return get(kHashAggrJitMinFuseWidth, 1); } int32_t hashAggrJitMaxFuseWidth() const { From 89fe122f81640c944defdb24945edc77ed8b3fd8 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 15 Jun 2026 22:58:19 +0800 Subject: [PATCH 85/98] fix another decimal sum bug --- bolt/jit/aggregation/ops/DecimalSumOps.cpp | 64 +++++++++++++++------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/bolt/jit/aggregation/ops/DecimalSumOps.cpp b/bolt/jit/aggregation/ops/DecimalSumOps.cpp index 2f947aac1..0428537e8 100644 --- a/bolt/jit/aggregation/ops/DecimalSumOps.cpp +++ b/bolt/jit/aggregation/ops/DecimalSumOps.cpp 
@@ -203,28 +203,54 @@ void emitDecimalAddWithOverflow( auto* i64Ty = b.getInt64Ty(); auto* zero128 = llvm::ConstantInt::get(i128Ty, 0); - auto* oldSum = codegen.loadValue(group, i128Ty, sumOffset); - auto* newSum = b.CreateAdd(oldSum, addend); - codegen.storeValue(group, i128Ty, sumOffset, newSum); + auto* lhs = codegen.loadValue(group, i128Ty, sumOffset); + auto* rhs = addend; + + // Mirror DecimalUtil::addWithOverflow + addUnsignedValues exactly (i128 sum + // kept as low 127 bits plus a separate overflow counter), instead of a + // sign-flip heuristic which diverges once the true i128 magnitude overflows. + // + // same sign: + // mag = (|lhs| + |rhs|) & ~(1<<127) // low 127 bits + // carry = (|lhs| + |rhs|) >> 127 // bit 127 + // both negative -> sum = -mag, overflow = -carry + // both positive -> sum = mag, overflow = carry + // different sign: + // sum = lhs + rhs, overflow = 0 + auto* lhsNeg = b.CreateICmpSLT(lhs, zero128); + auto* rhsNeg = b.CreateICmpSLT(rhs, zero128); + auto* sameSign = b.CreateICmpEQ(lhsNeg, rhsNeg); + auto* bothNeg = b.CreateAnd(lhsNeg, rhsNeg); + + // Magnitudes for the same-sign path: negate operands when both negative so + // the unsigned addition operates on |lhs|, |rhs| (matches addUnsignedValues). 
+ auto* absLhs = b.CreateSelect(bothNeg, b.CreateNeg(lhs), lhs); + auto* absRhs = b.CreateSelect(bothNeg, b.CreateNeg(rhs), rhs); + auto* unsignedSum = b.CreateAdd(absLhs, absRhs); + auto* mask127 = + llvm::ConstantInt::get(i128Ty, llvm::APInt::getSignedMinValue(128)); + // mag = unsignedSum & ~(1<<127) + auto* magnitude = b.CreateAnd(unsignedSum, b.CreateNot(mask127)); + // carry = (unsignedSum >> 127) & 1, as i64 + auto* carryBit = b.CreateAnd( + b.CreateLShr(unsignedSum, llvm::ConstantInt::get(i128Ty, 127)), + llvm::ConstantInt::get(i128Ty, 1)); + auto* carry64 = b.CreateTrunc(carryBit, i64Ty); + + auto* sameSignSum = b.CreateSelect(bothNeg, b.CreateNeg(magnitude), magnitude); + auto* sameSignOverflow = + b.CreateSelect(bothNeg, b.CreateNeg(carry64), carry64); - // Mirror jitHashAggrAddWithOverflow: - // +1 if a>0 && b>0 && result<0 (positive overflow) - // -1 if a<0 && b<0 && result>=0 (negative overflow) - auto* aPos = b.CreateICmpSGT(oldSum, zero128); - auto* bPos = b.CreateICmpSGT(addend, zero128); - auto* rNeg = b.CreateICmpSLT(newSum, zero128); - auto* posOverflow = b.CreateAnd(b.CreateAnd(aPos, bPos), rNeg); - - auto* aNeg = b.CreateICmpSLT(oldSum, zero128); - auto* bNeg = b.CreateICmpSLT(addend, zero128); - auto* rNonNeg = b.CreateICmpSGE(newSum, zero128); - auto* negOverflow = b.CreateAnd(b.CreateAnd(aNeg, bNeg), rNonNeg); - - auto* carry = b.CreateSub( - b.CreateZExt(posOverflow, i64Ty), b.CreateZExt(negOverflow, i64Ty)); + auto* diffSignSum = b.CreateAdd(lhs, rhs); + + auto* newSum = b.CreateSelect(sameSign, sameSignSum, diffSignSum); + auto* overflowDelta = + b.CreateSelect(sameSign, sameSignOverflow, b.getInt64(0)); + + codegen.storeValue(group, i128Ty, sumOffset, newSum); auto* oldOverflow = codegen.loadValue(group, i64Ty, overflowOffset); codegen.storeValue( - group, i64Ty, overflowOffset, b.CreateAdd(oldOverflow, carry)); + group, i64Ty, overflowOffset, b.CreateAdd(oldOverflow, overflowDelta)); } } // namespace bytedance::bolt::jit From 
7fc6b85f4e8fb72c8dd004fec975cf6b665b905a Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Tue, 16 Jun 2026 11:51:20 +0800 Subject: [PATCH 86/98] change hash aggr jit log level --- bolt/exec/GroupingSet.cpp | 24 ++++++++++++------------ bolt/jit/aggregation/HashAggrJit.cpp | 7 ++----- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 07ed48761..7088d0984 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -987,7 +987,7 @@ const SelectivityVector& GroupingSet::getSelectivityVector( void GroupingSet::maybeCreateHashAggrJitPlan() { hashAggrJitChunks_.clear(); if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { - VLOG(1) << "HashAggrJit plan disabled: enableHashAggrJit=" + LOG(INFO) << "HashAggrJit plan disabled: enableHashAggrJit=" << queryConfig_.enableHashAggrJit() << " isGlobal=" << isGlobal_ << " ignoreNullKeys=" << ignoreNullKeys_; return; @@ -998,7 +998,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { const auto maxFuseWidth = std::max(1, queryConfig_.hashAggrJitMaxFuseWidth()); const auto minChunkWidth = minFuseWidth; - VLOG(1) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ + LOG(INFO) << "HashAggrJit planning starts: isRawInput=" << isRawInput_ << " isPartial=" << isPartial_ << " aggregates=" << aggregates_.size() << " minFuseWidth=" << minFuseWidth @@ -1010,7 +1010,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { auto flushChunk = [&]() { if (currentChunkSlots.size() < minChunkWidth) { if (!currentChunkSlots.empty()) { - VLOG(1) << "HashAggrJit discard chunk candidate due to width " + LOG(INFO) << "HashAggrJit discard chunk candidate due to width " << currentChunkSlots.size() << " < " << minChunkWidth << "."; } @@ -1020,10 +1020,10 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { jit::HashAggrJitChunk chunk(std::move(currentChunkSlots)); if (chunk.codegen()) { hashAggrJitChunks_.push_back(std::move(chunk)); - VLOG(1) 
<< "HashAggrJit formed chunk: " + LOG(INFO) << "HashAggrJit formed chunk: " << hashAggrJitChunks_.back().getDescription(); } else { - VLOG(1) << "HashAggrJit chunk codegen failed for chunk " + LOG(INFO) << "HashAggrJit chunk codegen failed for chunk " << chunk.functionName(); } currentChunkSlots.clear(); @@ -1033,7 +1033,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { for (auto i = 0; i < aggregates_.size(); ++i) { auto slot = makeHashAggrJitSlot(i, aggregates_[i], isRawInput_, isPartial_); if (!slot.has_value()) { - VLOG(1) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" + LOG(INFO) << "HashAggrJit aggregate is not JIT-able: agg#" << i << "(" << aggregates_[i].name << ") isRawInput=" << isRawInput_ << " isPartialOutput=" << isPartial_ << " inputTypes=[" << [&]() { @@ -1071,7 +1071,7 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { } flushChunk(); - VLOG(1) << "HashAggrJit planning finished: totalChunks=" + LOG(INFO) << "HashAggrJit planning finished: totalChunks=" << hashAggrJitChunks_.size(); } @@ -1083,7 +1083,7 @@ void GroupingSet::runHashAggrJitAddChunks( std::vector& jitExecuted) { if (hashAggrJitChunks_.empty() || hasSpilled() || bypassProbeHT_ || supportRowBasedOutput_ || !activeRows_.isAllSelected()) { - VLOG(1) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() + LOG(INFO) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() << " hasSpilled=" << hasSpilled() << " bypassProbeHT=" << bypassProbeHT_ << " supportRowBasedOutput=" << supportRowBasedOutput_ @@ -1104,7 +1104,7 @@ void GroupingSet::runHashAggrJitAddChunks( std::vector newGroupPtrs; for (auto& chunk : hashAggrJitChunks_) { if (!chunk.isCodegenReady()) { - VLOG(1) << "HashAggrJit chunk is not codegen-ready, skip add: " + LOG(INFO) << "HashAggrJit chunk is not codegen-ready, skip add: " << chunk.getDescription(); continue; } @@ -1183,7 +1183,7 @@ void GroupingSet::runHashAggrJitAddChunks( } if (!canRunChunk) { - VLOG(1) << "HashAggrJit chunk cannot 
run add path, fallback to non-JIT: " + LOG(INFO) << "HashAggrJit chunk cannot run add path, fallback to non-JIT: " << chunk.getDescription() << " reason=" << skipReason; continue; } @@ -1214,7 +1214,7 @@ void GroupingSet::runHashAggrJitExtractChunks( std::vector& jitExtracted) { if (hashAggrJitChunks_.empty() || groups.empty() || hasSpilled() || supportRowBasedOutput_) { - VLOG(1) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() + LOG(INFO) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() << " groups=" << groups.size() << " hasSpilled=" << hasSpilled() << " supportRowBasedOutput=" << supportRowBasedOutput_; return; @@ -1294,7 +1294,7 @@ void GroupingSet::runHashAggrJitExtractChunks( resultPtrs[slotIndex] = reinterpret_cast(&outputRuntimes[slotIndex]); } if (!canRunChunk) { - VLOG(1) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " + LOG(INFO) << "HashAggrJit chunk cannot run extract path, fallback to non-JIT: " << chunk.getDescription() << " reason=" << skipReason; continue; } diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index e0c49c5fc..bfcda02da 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -33,12 +33,9 @@ void logHashAggrJitFunctionIR( llvm::StringRef functionName, llvm::StringRef stage, bool hasError) { - if (!VLOG_IS_ON(1)) { - return; - } const auto* function = module.getFunction(functionName); if (function == nullptr) { - VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + LOG(INFO) << "HashAggrJit generated LLVM IR for chunk " << moduleKey << " stage=" << stage.str() << " function=" << functionName.str() << " error=" << hasError << ": "; return; @@ -47,7 +44,7 @@ void logHashAggrJitFunctionIR( llvm::raw_string_ostream out(ir); function->print(out); out.flush(); - VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + LOG(INFO) << "HashAggrJit generated LLVM IR for chunk 
" << moduleKey << " stage=" << stage.str() << " function=" << functionName.str() << " error=" << hasError << ":\n" << ir; From c6acb2c2abcbc96afa2dbbb2cfd57a87ef60feed Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 17 Jun 2026 15:48:17 +0800 Subject: [PATCH 87/98] don't rely on numNulls_ when hash aggr jit is enabled --- bolt/exec/Aggregate.h | 47 ++++++++++++++----- bolt/exec/AggregationHook.h | 18 ++++--- bolt/exec/GroupingSet.cpp | 1 + bolt/exec/tests/AggregationTest.cpp | 18 +++++++ .../lib/aggregates/AverageAggregateBase.h | 4 +- .../aggregates/CentralMomentsAggregatesBase.h | 4 +- .../lib/aggregates/DecimalAggregate.h | 4 +- .../lib/aggregates/SimpleNumericAggregate.h | 4 +- .../lib/aggregates/SumAggregateBase.h | 2 +- .../lib/aggregates/tests/SumTestBase.h | 4 +- .../prestosql/aggregates/CountAggregate.cpp | 13 +++-- .../aggregates/VarianceAggregates.cpp | 4 +- .../prestosql/aggregates/tests/SumTest.cpp | 7 +-- .../sparksql/aggregates/DecimalSumAggregate.h | 4 +- 14 files changed, 94 insertions(+), 40 deletions(-) diff --git a/bolt/exec/Aggregate.h b/bolt/exec/Aggregate.h index bbe2013e9..d6ec0c88f 100644 --- a/bolt/exec/Aggregate.h +++ b/bolt/exec/Aggregate.h @@ -33,6 +33,8 @@ #include #include +#include + #include "bolt/common/memory/HashStringAllocator.h" #include "bolt/core/PlanNode.h" #include "bolt/core/QueryConfig.h" @@ -160,6 +162,10 @@ class Aggregate { setOffsetsInternal(offset, nullByte, nullMask, rowSizeOffset); } + void markNullCountUnknown() { + numNulls_ = std::nullopt; + } + // Initializes null flags and accumulators for newly encountered groups. This // function should be called only once for each group. 
// @@ -387,7 +393,15 @@ class Aggregate { } bool isNull(char* group) const { - return numNulls_ && (group[nullByte_] & nullMask_); + return mayHaveNulls() && (group[nullByte_] & nullMask_); + } + + bool hasNoNulls() const { + return numNulls_.has_value() && *numNulls_ == 0; + } + + bool mayHaveNulls() const { + return !numNulls_.has_value() || *numNulls_ > 0; } // Sets null flag for all specified groups to true. @@ -396,7 +410,9 @@ class Aggregate { for (auto i : indices) { groups[i][nullByte_] |= nullMask_; } - numNulls_ += indices.size(); + if (numNulls_.has_value()) { + *numNulls_ += indices.size(); + } } inline bool setNull(char* group) { @@ -404,18 +420,23 @@ class Aggregate { return false; } group[nullByte_] |= nullMask_; - ++numNulls_; + if (numNulls_.has_value()) { + ++*numNulls_; + } return true; } inline bool clearNull(char* group) { - if (numNulls_) { - uint8_t mask = group[nullByte_]; - if (mask & nullMask_) { - group[nullByte_] = mask & ~nullMask_; - --numNulls_; - return true; + if (!mayHaveNulls()) { + return false; + } + uint8_t mask = group[nullByte_]; + if (mask & nullMask_) { + group[nullByte_] = mask & ~nullMask_; + if (numNulls_.has_value()) { + --*numNulls_; } + return true; } return false; } @@ -476,9 +497,11 @@ class Aggregate { int32_t rowSizeOffset_ = 0; // Number of null accumulators in the current state of the aggregation - // operator for this aggregate. If 0, clearing the null as part of update - // is not needed. - uint64_t numNulls_ = 0; + // operator for this aggregate. 
+ // - 0 => known that no group is null + // - N > 0 => known exact null count + // - nullopt => unknown; must rely on per-group null bit + std::optional numNulls_{0}; HashStringAllocator* allocator_{nullptr}; memory::MemoryPool* pool_{nullptr}; std::shared_ptr expressionEvaluator_{nullptr}; diff --git a/bolt/exec/AggregationHook.h b/bolt/exec/AggregationHook.h index b1f25b1d0..390c3fb18 100644 --- a/bolt/exec/AggregationHook.h +++ b/bolt/exec/AggregationHook.h @@ -30,6 +30,8 @@ #pragma once +#include + #include "bolt/common/base/CheckedArithmetic.h" #include "bolt/common/base/Range.h" #include "bolt/vector/LazyVector.h" @@ -60,7 +62,7 @@ class AggregationHook : public ValueHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls) + std::optional& numNulls) : offset_(offset), nullByte_(nullByte), nullMask_(nullMask), @@ -91,11 +93,13 @@ class AggregationHook : public ValueHook { } inline bool clearNull(char* group) { - if (*numNulls_) { + if (!numNulls_.has_value() || *numNulls_ > 0) { uint8_t mask = group[nullByte_]; if (mask & nullMask_) { group[nullByte_] = mask & clearNullMask_; - --*numNulls_; + if (numNulls_.has_value()) { + --*numNulls_; + } return true; } } @@ -108,7 +112,7 @@ class AggregationHook : public ValueHook { const uint8_t nullMask_; const uint8_t clearNullMask_; char* const* const groups_; - uint64_t* numNulls_; + std::optional& numNulls_; }; namespace { @@ -132,7 +136,7 @@ class SumHook final : public AggregationHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls) + std::optional& numNulls) : AggregationHook(offset, nullByte, nullMask, groups, numNulls) {} Kind kind() const override { @@ -174,7 +178,7 @@ class SimpleCallableHook final : public AggregationHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls, + std::optional& numNulls, UpdateSingleValue updateSingleValue) : AggregationHook(offset, nullByte, nullMask, groups, numNulls), updateSingleValue_(updateSingleValue) 
{} @@ -203,7 +207,7 @@ class MinMaxHook final : public AggregationHook { int32_t nullByte, uint8_t nullMask, char** groups, - uint64_t* numNulls) + std::optional& numNulls) : AggregationHook(offset, nullByte, nullMask, groups, numNulls) {} Kind kind() const override { diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 7088d0984..3c0338239 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1202,6 +1202,7 @@ void GroupingSet::runHashAggrJitAddChunks( inputRuntimePtrs.data(), inputsMayHaveNulls); for (const auto& slot : chunk.slots()) { + aggregates_[slot.aggregateIndex].function->markNullCountUnknown(); jitExecuted[slot.aggregateIndex] = 1; } } diff --git a/bolt/exec/tests/AggregationTest.cpp b/bolt/exec/tests/AggregationTest.cpp index 04d819a0c..bc4fc2ef1 100644 --- a/bolt/exec/tests/AggregationTest.cpp +++ b/bolt/exec/tests/AggregationTest.cpp @@ -814,6 +814,24 @@ TEST_P(AggregationTest, setNull) { EXPECT_TRUE(aggregate.isNullTest(&group)); } +TEST_P(AggregationTest, isNullUsesBitWhenNullCountUnknown) { + AggregateFunc aggregate(BIGINT()); + int32_t nullOffset = 0; + aggregate.setOffsets( + 0, + RowContainer::nullByte(nullOffset), + RowContainer::nullMask(nullOffset), + 0); + + char group{0}; + EXPECT_TRUE(aggregate.setNullTest(&group)); + aggregate.markNullCountUnknown(); + + EXPECT_TRUE(aggregate.isNullTest(&group)); + EXPECT_TRUE(aggregate.clearNullTest(&group)); + EXPECT_FALSE(aggregate.isNullTest(&group)); +} + TEST_P(AggregationTest, hashmodes) { rng_.seed(1); auto rowType = diff --git a/bolt/functions/lib/aggregates/AverageAggregateBase.h b/bolt/functions/lib/aggregates/AverageAggregateBase.h index aff93ec4a..d24e1825c 100644 --- a/bolt/functions/lib/aggregates/AverageAggregateBase.h +++ b/bolt/functions/lib/aggregates/AverageAggregateBase.h @@ -208,7 +208,7 @@ class AverageAggregateBase : public exec::Aggregate { groups[i], TAccumulator(decodedRaw_.valueAt(i))); }, nulls); - } else if 
(!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i]); @@ -251,7 +251,7 @@ class AverageAggregateBase : public exec::Aggregate { group, TAccumulator(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { const TInput* data = decodedRaw_.data(); TAccumulator totalSum(0); rows.applyToSelected([&](vector_size_t i) { totalSum += data[i]; }); diff --git a/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h b/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h index 4e9445d77..b639784d0 100644 --- a/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h +++ b/bolt/functions/lib/aggregates/CentralMomentsAggregatesBase.h @@ -263,7 +263,7 @@ class CentralMomentsAggregatesBase : public exec::Aggregate { } updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i]); @@ -294,7 +294,7 @@ class CentralMomentsAggregatesBase : public exec::Aggregate { updateNonNullValue(group, decodedRaw_.valueAt(i)); } }); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); CentralMomentsAccumulator accData; rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); diff --git a/bolt/functions/lib/aggregates/DecimalAggregate.h b/bolt/functions/lib/aggregates/DecimalAggregate.h index 6a0f9e496..9f34c4eec 
100644 --- a/bolt/functions/lib/aggregates/DecimalAggregate.h +++ b/bolt/functions/lib/aggregates/DecimalAggregate.h @@ -132,7 +132,7 @@ class DecimalAggregate : public exec::Aggregate { groups[i], TResultType(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], TResultType(data[i])); @@ -178,7 +178,7 @@ class DecimalAggregate : public exec::Aggregate { group, TResultType(decodedRaw_.valueAt(i))); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { const TInputType* data = decodedRaw_.data(); LongDecimalWithOverflowState accumulator; rows.applyToSelected([&](vector_size_t i) { diff --git a/bolt/functions/lib/aggregates/SimpleNumericAggregate.h b/bolt/functions/lib/aggregates/SimpleNumericAggregate.h index ce71fcfda..1a4eb1c37 100644 --- a/bolt/functions/lib/aggregates/SimpleNumericAggregate.h +++ b/bolt/functions/lib/aggregates/SimpleNumericAggregate.h @@ -116,7 +116,7 @@ class SimpleNumericAggregate : public exec::Aggregate { exec::Aggregate::nullByte_, exec::Aggregate::nullMask_, groups, - &this->exec::Aggregate::numNulls_, + this->exec::Aggregate::numNulls_, updateSingleValue); auto indices = decoded.indices(); @@ -236,7 +236,7 @@ class SimpleNumericAggregate : public exec::Aggregate { exec::Aggregate::nullByte_, exec::Aggregate::nullMask_, groups, - &this->exec::Aggregate::numNulls_); + this->exec::Aggregate::numNulls_); // The decoded vector does not really keep the info from the 'rows', except // for the 'upper bound' of it. 
In case not all rows are selected we need to // generate proper indices, which we 'indirect' through the ones we got from diff --git a/bolt/functions/lib/aggregates/SumAggregateBase.h b/bolt/functions/lib/aggregates/SumAggregateBase.h index acafb3766..89c978b49 100644 --- a/bolt/functions/lib/aggregates/SumAggregateBase.h +++ b/bolt/functions/lib/aggregates/SumAggregateBase.h @@ -173,7 +173,7 @@ class SumAggregateBase return; } - if (exec::Aggregate::numNulls_) { + if (exec::Aggregate::mayHaveNulls()) { BaseAggregate::template updateGroups( groups, rows, arg, &updateSingleValue, false); } else { diff --git a/bolt/functions/lib/aggregates/tests/SumTestBase.h b/bolt/functions/lib/aggregates/tests/SumTestBase.h index 564d7770c..e3a9a0ffa 100644 --- a/bolt/functions/lib/aggregates/tests/SumTestBase.h +++ b/bolt/functions/lib/aggregates/tests/SumTestBase.h @@ -53,13 +53,13 @@ void testHookLimits(bool expectOverflow = false) { sumRow.sum = 0; ResultType expected = 0; char* row = reinterpret_cast(&sumRow); - uint64_t numNulls = 0; + std::optional numNulls = 0; bytedance::bolt::aggregate::SumHook hook( offsetof(SumRow, sum), offsetof(SumRow, nulls), 0, &row, - &numNulls); + numNulls); // Adding limit should not overflow. 
ASSERT_NO_THROW(hook.addValue(0, &limit)); diff --git a/bolt/functions/prestosql/aggregates/CountAggregate.cpp b/bolt/functions/prestosql/aggregates/CountAggregate.cpp index 59895976a..62c6f9860 100644 --- a/bolt/functions/prestosql/aggregates/CountAggregate.cpp +++ b/bolt/functions/prestosql/aggregates/CountAggregate.cpp @@ -116,15 +116,22 @@ class CountAggregate : public SimpleNumericAggregate { folly::Range indices) override { for (auto i : indices) { // result of count is never null + groups[i][nullByte_] &= ~nullMask_; *value(groups[i]) = (int64_t)0; } } FLATTEN void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override { - BaseAggregate::doExtractValues(groups, numGroups, result, [&](char* group) { - return *value(group); - }); + auto* vector = (*result)->as>(); + BOLT_CHECK(vector); + vector->resize(numGroups); + vector->clearAllNulls(); + + auto* rawValues = vector->mutableRawValues(); + for (vector_size_t i = 0; i < numGroups; ++i) { + rawValues[i] = *value(groups[i]); + } } void addRawInput( diff --git a/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp b/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp index b48e44a8d..5e1f08d05 100644 --- a/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp +++ b/bolt/functions/prestosql/aggregates/VarianceAggregates.cpp @@ -252,7 +252,7 @@ class VarianceAggregate : public exec::Aggregate { updateNonNullValue(groups[i], decodedRaw_.valueAt(i)); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i]); @@ -294,7 +294,7 @@ class VarianceAggregate : public exec::Aggregate { updateNonNullValue(group, decodedRaw_.valueAt(i)); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if 
(exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { const T* data = decodedRaw_.data(); VarianceAccumulator accData; rows.applyToSelected([&](vector_size_t i) { accData.update(data[i]); }); diff --git a/bolt/functions/prestosql/aggregates/tests/SumTest.cpp b/bolt/functions/prestosql/aggregates/tests/SumTest.cpp index 13c880c06..fc73946cc 100644 --- a/bolt/functions/prestosql/aggregates/tests/SumTest.cpp +++ b/bolt/functions/prestosql/aggregates/tests/SumTest.cpp @@ -446,18 +446,19 @@ TEST_F(SumTest, hook) { sumRow.sum = 0; char* row = reinterpret_cast(&sumRow); - uint64_t numNulls = 1; + std::optional numNulls = 1; aggregate::SumHook hook( offsetof(SumRow, sum), offsetof(SumRow, nulls), 1, &row, - &numNulls); + numNulls); int64_t value = 11; hook.addValue(0, &value); EXPECT_EQ(0, sumRow.nulls); - EXPECT_EQ(0, numNulls); + ASSERT_TRUE(numNulls.has_value()); + EXPECT_EQ(0, *numNulls); EXPECT_EQ(value, sumRow.sum); } diff --git a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h index 5d565953c..9a310f23c 100644 --- a/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h +++ b/bolt/functions/sparksql/aggregates/DecimalSumAggregate.h @@ -231,7 +231,7 @@ class DecimalSumAggregate : public exec::Aggregate { groups[i], decodedRaw_.valueAt(i), false); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); rows.applyToSelected([&](vector_size_t i) { updateNonNullValue(groups[i], data[i], false); @@ -273,7 +273,7 @@ class DecimalSumAggregate : public exec::Aggregate { group, decodedRaw_.valueAt(i), false); }, nulls); - } else if (!exec::Aggregate::numNulls_ && decodedRaw_.isIdentityMapping()) { + } else if (exec::Aggregate::hasNoNulls() && decodedRaw_.isIdentityMapping()) { auto data = decodedRaw_.data(); DecimalSum decimalSum; 
rows.applyToSelected([&](vector_size_t i) { From 2828261b81dfae01f6375660da9d95017736b477 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 17 Jun 2026 16:30:44 +0800 Subject: [PATCH 88/98] add more metrics about hash aggr jit --- bolt/common/base/AggregationStats.h | 7 +++++++ bolt/exec/GroupingSet.cpp | 20 ++++++++++++++------ bolt/exec/Operator.cpp | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/bolt/common/base/AggregationStats.h b/bolt/common/base/AggregationStats.h index dc07b70b0..31c1a8fc9 100644 --- a/bolt/common/base/AggregationStats.h +++ b/bolt/common/base/AggregationStats.h @@ -28,5 +28,12 @@ struct AggregationStats { uint64_t aggOutputTimeNs{0}; uint64_t aggProbeBypassTimeNs{0}; uint64_t aggProbeBypassCount{0}; + // Hash aggregation JIT fine-grained timing. + // One-time codegen (LLVM compile) time for the JIT plan. + uint64_t aggJitCodegenTimeNs{0}; + // JIT-executed part of the agg function update time. + uint64_t aggFunctionJitTimeNs{0}; + // JIT-executed part of the extracting groups time. 
+ uint64_t aggExtractGroupsJitTimeNs{0}; }; } // namespace bytedance::bolt::common diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 3c0338239..8c0581dad 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -985,6 +985,7 @@ const SelectivityVector& GroupingSet::getSelectivityVector( #ifdef ENABLE_BOLT_JIT void GroupingSet::maybeCreateHashAggrJitPlan() { + NanosecondTimer codegenTimer(&stats_.aggJitCodegenTimeNs); hashAggrJitChunks_.clear(); if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { LOG(INFO) << "HashAggrJit plan disabled: enableHashAggrJit=" @@ -1193,14 +1194,18 @@ void GroupingSet::runHashAggrJitAddChunks( for (auto i = 0; i < newGroups.size(); ++i) { newGroupPtrs[i] = groups[newGroups[i]]; } + NanosecondTimer jitTimer(&stats_.aggFunctionJitTimeNs); chunk.init(newGroupPtrs.data(), newGroups.size()); } - chunk.addDense( - groups, - activeRows_.end(), - inputRuntimePtrs.data(), - inputsMayHaveNulls); + { + NanosecondTimer jitTimer(&stats_.aggFunctionJitTimeNs); + chunk.addDense( + groups, + activeRows_.end(), + inputRuntimePtrs.data(), + inputsMayHaveNulls); + } for (const auto& slot : chunk.slots()) { aggregates_[slot.aggregateIndex].function->markNullCountUnknown(); jitExecuted[slot.aggregateIndex] = 1; @@ -1299,7 +1304,10 @@ void GroupingSet::runHashAggrJitExtractChunks( << chunk.getDescription() << " reason=" << skipReason; continue; } - chunk.extract(groups.data(), groups.size(), resultPtrs.data()); + { + NanosecondTimer jitTimer(&stats_.aggExtractGroupsJitTimeNs); + chunk.extract(groups.data(), groups.size(), resultPtrs.data()); + } for (const auto& slot : chunk.slots()) { jitExtracted[slot.aggregateIndex] = 1; } diff --git a/bolt/exec/Operator.cpp b/bolt/exec/Operator.cpp index e599e88ea..8ec7c8691 100644 --- a/bolt/exec/Operator.cpp +++ b/bolt/exec/Operator.cpp @@ -833,6 +833,25 @@ void Operator::recordGroupingSetStats(const common::AggregationStats& stats) { "aggProbeBypassCount", 
RuntimeCounter{(int64_t)stats.aggProbeBypassCount}); } + if (stats.aggJitCodegenTimeNs) { + lockedStats->addRuntimeStat( + "aggJitCodegenTimeNs", + RuntimeCounter{ + (int64_t)stats.aggJitCodegenTimeNs, RuntimeCounter::Unit::kNanos}); + } + if (stats.aggFunctionJitTimeNs) { + lockedStats->addRuntimeStat( + "aggFunctionJitTimeNs", + RuntimeCounter{ + (int64_t)stats.aggFunctionJitTimeNs, RuntimeCounter::Unit::kNanos}); + } + if (stats.aggExtractGroupsJitTimeNs) { + lockedStats->addRuntimeStat( + "aggExtractGroupsJitTimeNs", + RuntimeCounter{ + (int64_t)stats.aggExtractGroupsJitTimeNs, + RuntimeCounter::Unit::kNanos}); + } } void Operator::recordHashBuildSpillStats( From 34c7bd4101aac4ef92fb52d974a213a564d5d570 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 17 Jun 2026 16:49:33 +0800 Subject: [PATCH 89/98] hash aggr jit: compile chunks asynchronously in parallel Submit each chunk's codegen to the global CPU executor instead of compiling synchronously on the first batch. Chunks start not-ready (std::atomic ready_) and the query thread falls back to the non-JIT path until compilation completes, then switches to JIT. Also removes benchmark warmup so first-batch compile latency is measured. 
--- bolt/exec/GroupingSet.cpp | 62 +++++++++++++++---- bolt/exec/GroupingSet.h | 14 ++++- bolt/exec/benchmarks/HashAggrJitBenchmark.cpp | 7 +-- bolt/jit/aggregation/HashAggrJit.cpp | 12 +++- bolt/jit/aggregation/HashAggrJit.h | 7 ++- 5 files changed, 81 insertions(+), 21 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 8c0581dad..e8912707c 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -29,6 +29,7 @@ */ #include "bolt/exec/GroupingSet.h" +#include #include #include "bolt/common/base/Exceptions.h" #include "bolt/common/base/SpillConfig.h" @@ -38,6 +39,7 @@ #include "bolt/exec/OperatorUtils.h" #include "bolt/exec/RowToColumnVector.h" #ifdef ENABLE_BOLT_JIT +#include #include "bolt/jit/aggregation/HashAggrJit.h" #endif #include "bolt/type/Type.h" @@ -342,6 +344,10 @@ GroupingSet::GroupingSet( } GroupingSet::~GroupingSet() { +#ifdef ENABLE_BOLT_JIT + // Ensure no background compilation task still references our chunks. + waitForHashAggrJitCompilation(); +#endif if (isGlobal_) { destroyGlobalAggregations(); } @@ -984,8 +990,19 @@ const SelectivityVector& GroupingSet::getSelectivityVector( } #ifdef ENABLE_BOLT_JIT +void GroupingSet::waitForHashAggrJitCompilation() { + for (auto& future : hashAggrJitCompileFutures_) { + if (future.valid()) { + stats_.aggJitCodegenTimeNs += std::move(future).get(); + } + } + hashAggrJitCompileFutures_.clear(); +} + void GroupingSet::maybeCreateHashAggrJitPlan() { - NanosecondTimer codegenTimer(&stats_.aggJitCodegenTimeNs); + // Wait for any background compilation tasks from a previous plan before + // tearing down the chunks they reference. 
+ waitForHashAggrJitCompilation(); hashAggrJitChunks_.clear(); if (!queryConfig_.enableHashAggrJit() || isGlobal_ || ignoreNullKeys_) { LOG(INFO) << "HashAggrJit plan disabled: enableHashAggrJit=" @@ -1018,15 +1035,27 @@ void GroupingSet::maybeCreateHashAggrJitPlan() { currentChunkSlots.clear(); return; } - jit::HashAggrJitChunk chunk(std::move(currentChunkSlots)); - if (chunk.codegen()) { - hashAggrJitChunks_.push_back(std::move(chunk)); - LOG(INFO) << "HashAggrJit formed chunk: " - << hashAggrJitChunks_.back().getDescription(); - } else { - LOG(INFO) << "HashAggrJit chunk codegen failed for chunk " - << chunk.functionName(); - } + // Register the chunk as not-ready and compile it in the background. The + // query thread falls back to the non-JIT path for slots whose chunk is not + // yet ready (see runHashAggrJitAddChunks / runHashAggrJitExtractChunks), + // then switches to JIT once compilation completes. Submitting all chunks + // up front lets the global CPU executor materialize them in parallel. 
+ hashAggrJitChunks_.push_back( + std::make_unique(std::move(currentChunkSlots))); + auto* chunk = hashAggrJitChunks_.back().get(); + LOG(INFO) << "HashAggrJit formed chunk (compiling in background): " + << chunk->getDescription(); + hashAggrJitCompileFutures_.push_back( + folly::via(folly::getGlobalCPUExecutor().get(), [chunk]() -> uint64_t { + const auto start = std::chrono::steady_clock::now(); + if (!chunk->codegen()) { + LOG(INFO) << "HashAggrJit chunk codegen failed for chunk " + << chunk->functionName(); + } + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + })); currentChunkSlots.clear(); currentChunkSlots.reserve(maxFuseWidth); }; @@ -1103,7 +1132,8 @@ void GroupingSet::runHashAggrJitAddChunks( std::vector inputVectors; std::vector inputRuntimePtrs; std::vector newGroupPtrs; - for (auto& chunk : hashAggrJitChunks_) { + for (auto& chunkPtr : hashAggrJitChunks_) { + auto& chunk = *chunkPtr; if (!chunk.isCodegenReady()) { LOG(INFO) << "HashAggrJit chunk is not codegen-ready, skip add: " << chunk.getDescription(); @@ -1233,7 +1263,15 @@ void GroupingSet::runHashAggrJitExtractChunks( std::vector> rowOutputChildPtrs; std::vector resultPtrs; - for (auto& chunk : hashAggrJitChunks_) { + for (auto& chunkPtr : hashAggrJitChunks_) { + auto& chunk = *chunkPtr; + if (!chunk.isCodegenReady()) { + // Background compilation not finished yet; leave these aggregates for the + // non-JIT extract path (jitExtracted stays 0 for their slots). 
+ LOG(INFO) << "HashAggrJit chunk is not codegen-ready, skip extract: " + << chunk.getDescription(); + continue; + } const auto numSlots = chunk.slots().size(); outputRuntimes.assign(numSlots, jit::HashAggrJitOutputRuntime{}); rowOutputChildren.resize(numSlots); diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index ed89d7ee6..e6c0f8a83 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -40,6 +40,7 @@ #include "bolt/exec/TreeOfLosers.h" #include "bolt/exec/VectorHasher.h" #ifdef ENABLE_BOLT_JIT +#include #include "bolt/jit/aggregation/HashAggrJit.h" #endif #include "bolt/vector/DecodedVector.h" @@ -288,6 +289,9 @@ class GroupingSet { #ifdef ENABLE_BOLT_JIT void maybeCreateHashAggrJitPlan(); + // Blocks until all outstanding background JIT compilation tasks finish, so the + // chunks they reference can be safely destroyed or replaced. + void waitForHashAggrJitCompilation(); void runHashAggrJitAddChunks( char** groups, folly::Range newGroups, @@ -462,7 +466,15 @@ class GroupingSet { std::vector> distinctAggregations_; #ifdef ENABLE_BOLT_JIT - std::vector hashAggrJitChunks_; + // unique_ptr gives each chunk a stable address so background compilation + // tasks can safely hold a raw pointer; HashAggrJitChunk is also non-movable + // (holds a std::atomic). Chunks start not-ready and flip ready once their + // background codegen completes. + std::vector> hashAggrJitChunks_; + // Outstanding background JIT compilation tasks for hashAggrJitChunks_. Each + // future returns the chunk's codegen time in nanoseconds so it can be + // aggregated into stats_.aggJitCodegenTimeNs on the query thread. 
+ std::vector> hashAggrJitCompileFutures_; #endif // True if any aggregate accumulator allocates memory outside RowContainer's diff --git a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp index 5805db3de..a016d646d 100644 --- a/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp +++ b/bolt/exec/benchmarks/HashAggrJitBenchmark.cpp @@ -332,10 +332,9 @@ class HashAggrJitBenchmark : public VectorTestBase { testCase->plan = makePlan(rows, aggregates, planKind); testCase->minFuseWidth = minFuseWidth; testCase->maxFuseWidth = maxFuseWidth; - // Warm up both paths so the benchmark compares steady-state execution and - // doesn't charge one-time plan setup / JIT compilation to the first sample. - run(testCase->plan, false, minFuseWidth, maxFuseWidth); - run(testCase->plan, true, minFuseWidth, maxFuseWidth); + // Warm-up intentionally disabled: we want each sample to include one-time + // plan setup / JIT compilation so the async/parallel-compile optimization + // (first-batch latency) is reflected in the measurement. 
auto* testCasePtr = testCase.get(); folly::addBenchmark(__FILE__, name + "_nojit", [this, testCasePtr]() { run( diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index bfcda02da..afd4bc1d1 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1191,7 +1191,7 @@ bool isHashAggrJitSupportedType(TypeKind kind) { } bool HashAggrJitChunk::codegen() { - if (addDense_) { + if (ready_.load(std::memory_order_acquire)) { return true; } auto* jit = ThrustJITv2::getInstance(); @@ -1231,8 +1231,14 @@ bool HashAggrJitChunk::codegen() { addDenseNoNull_ = reinterpret_cast( module_->getFuncPtr(addNoNullFn)); extract_ = reinterpret_cast(module_->getFuncPtr(extractFn)); - return init_ != nullptr && addDense_ != nullptr && - addDenseNoNull_ != nullptr && extract_ != nullptr; + if (init_ == nullptr || addDense_ == nullptr || addDenseNoNull_ == nullptr || + extract_ == nullptr) { + return false; + } + // Publish all function pointers before flipping ready_ so the query thread + // observing isCodegenReady()==true also sees fully-initialized pointers. 
+ ready_.store(true, std::memory_order_release); + return true; } } // namespace bytedance::bolt::jit diff --git a/bolt/jit/aggregation/HashAggrJit.h b/bolt/jit/aggregation/HashAggrJit.h index 94d477040..148070882 100644 --- a/bolt/jit/aggregation/HashAggrJit.h +++ b/bolt/jit/aggregation/HashAggrJit.h @@ -3,6 +3,7 @@ #ifdef ENABLE_BOLT_JIT #include +#include #include #include #include @@ -281,7 +282,7 @@ class HashAggrJitChunk { bool codegen(); bool isCodegenReady() const { - return addDense_ != nullptr; + return ready_.load(std::memory_order_acquire); } void init(char** newGroups, int32_t numNewGroups) const { @@ -322,6 +323,10 @@ class HashAggrJitChunk { HashAggrJitAddDenseFunc addDense_{nullptr}; HashAggrJitAddDenseFunc addDenseNoNull_{nullptr}; HashAggrJitExtractFunc extract_{nullptr}; + // Published last by codegen() (release) and read by isCodegenReady() + // (acquire). Lets the query thread fall back to non-JIT while background + // compilation is still in progress, then switch to JIT once ready. 
+ std::atomic ready_{false}; }; } // namespace bytedance::bolt::jit From cec31983f854c6b200f2f4519835e0a3a9c7968a Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 17 Jun 2026 20:43:51 +0800 Subject: [PATCH 90/98] remove useless ir dump and function verify which could reduce jit compile time by 10% --- bolt/jit/aggregation/HashAggrJit.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index afd4bc1d1..c5ae64505 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -35,7 +34,7 @@ void logHashAggrJitFunctionIR( bool hasError) { const auto* function = module.getFunction(functionName); if (function == nullptr) { - LOG(INFO) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey << " stage=" << stage.str() << " function=" << functionName.str() << " error=" << hasError << ": "; return; @@ -44,7 +43,7 @@ void logHashAggrJitFunctionIR( llvm::raw_string_ostream out(ir); function->print(out); out.flush(); - LOG(INFO) << "HashAggrJit generated LLVM IR for chunk " << moduleKey + VLOG(1) << "HashAggrJit generated LLVM IR for chunk " << moduleKey << " stage=" << stage.str() << " function=" << functionName.str() << " error=" << hasError << ":\n" << ir; @@ -908,7 +907,7 @@ bool genInitIR( builder.SetInsertPoint(end); builder.CreateRetVoid(); - return !llvm::verifyFunction(*func, &llvm::errs()); + return true; } bool genAddDenseIR( @@ -998,7 +997,7 @@ bool genAddDenseIR( builder.SetInsertPoint(end); builder.CreateRetVoid(); - return !llvm::verifyFunction(*func, &llvm::errs()); + return true; } bool genExtractIR( @@ -1066,7 +1065,7 @@ bool genExtractIR( builder.SetInsertPoint(end); builder.CreateRetVoid(); - return !llvm::verifyFunction(*func, &llvm::errs()); + return true; } } // namespace 
From f02eda5d6d74490e26f22ed0c2513807600f44df Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Wed, 17 Jun 2026 21:58:40 +0800 Subject: [PATCH 91/98] remove useless docs --- doc/hash-aggr-jit-todolist.md | 169 -- doc/hash_aggr_jit_state_consistency_review.md | 182 -- doc/hashaggr-jit-benchmark.md | 1439 --------------- doc/hashaggr-jit-code-review.md | 133 -- hashaggr_jit_refactor_plan.md | 1571 ----------------- 5 files changed, 3494 deletions(-) delete mode 100644 doc/hash-aggr-jit-todolist.md delete mode 100644 doc/hash_aggr_jit_state_consistency_review.md delete mode 100644 doc/hashaggr-jit-benchmark.md delete mode 100644 doc/hashaggr-jit-code-review.md delete mode 100644 hashaggr_jit_refactor_plan.md diff --git a/doc/hash-aggr-jit-todolist.md b/doc/hash-aggr-jit-todolist.md deleted file mode 100644 index 027f9bf48..000000000 --- a/doc/hash-aggr-jit-todolist.md +++ /dev/null @@ -1,169 +0,0 @@ -# HashAggr JIT TODO List - -## Resolved(已处理,保留遗留风险备忘) - -### HashAggrJitOps 散布在各 aggregate + Aggregate.h 硬依赖 LLVM 头 - -**问题** -- `Aggregate.h` `#include HashAggrJit.h`,后者 `#include ` 等重头,导致 - 所有 include `Aggregate.h` 的 TU(JIT 开启时)被拖进 LLVM IR 头,编译时间膨胀。 -- 每个 aggregate 子类内嵌一组 `compileHashAggrJit*` static codegen,依赖 IRBuilder; - `DecimalSumAggregate.h` 模板头塞了 ~120 行 codegen,每个实例化点重复展开。 -- runtime helper(`jit_HashAggrSetFlat*` 等)散落在 `RowContainer.cpp`,decimal extract - helper 散落在 `SumAggregate.cpp` / `AverageAggregate.cpp`。 - -**本次处理(四点 + 遗留点,已编译验证)** -1. 剥离 LLVM 头出 `Aggregate.h`: - - 新建 `bolt/jit/aggregation/HashAggrJitTypes.h`(纯 metadata,无 LLVM): - state / decoded&output 描述符 / planContext / enum / `HashAggrJitDescriptor` - (`ops` 持有前向声明的 `HashAggrJitOps*`)/ `HashAggrJitSlot` / 三个自由函数声明 - / `getXxxOps()` 声明。 - - `HashAggrJit.h` 改为 `#include HashAggrJitTypes.h` + 仅保留 codegen-only - (`HashAggrJitOps` / `HashAggrJitExtractTarget` / `HashAggrJitCodegen` / `HashAggrJitChunk`)。 - - `Aggregate.h` 的 include 改为 `HashAggrJitTypes.h`,LLVM 头不再进公共头。 -2. 
各 aggregate codegen 迁到 `bolt/jit/aggregation/ops/*Ops.cpp`: - `CountOps / MinMaxOps / SumOps / AvgOps / DecimalSumOps / DecimalAvgOps`,各 `getXxxOps()`; - 编入 `bolt_thrustjit`。aggregate 子类只留 `supportsHashAggrJit` + `createHashAggrJitDescriptor` - (`.ops = jit::getXxxOps()`)。须留类内:MinMax 的虚函数 `jitKind()`、Decimal 的 - `sumType_` / `resultType()` 依赖。 -3. runtime helper 迁到 `bolt/jit/aggregation/runtime/`: - - `HashAggrRuntime.cpp`:`jit_HashAggrResizeVector` / `SetFlat*` / `SetPartialAvgDouble` - (原在 `RowContainer.cpp`)。 - - `HashAggrDecimalRuntime.cpp`(遗留点):`jit_HashAggrExtract{Final,Partial}Decimal{Sum,Avg}` - + `jitDecimalSumComputeFinal`(原在 `SumAggregate.cpp` / `AverageAggregate.cpp`)。 - - 两文件只依赖 vector + `bolt/type/DecimalUtil.h`,编入 `bolt_exec`(同符号空间、 - `ENABLE_EXPORTS`,仍 extern "C" + visibility default,dlsym 可解析)。 -4. 编译验证:`bolt_thrustjit` / `bolt_exec` / `bolt_aggregates` / - `bolt_functions_spark_aggregates` 均通过;nm 确认 `jit_HashAggr*` 12 个符号在新文件 - 以 `T` 导出,旧文件无残留定义。 - -**未做 / 遗留** -- 链接级端到端单测运行验证未做(当前为 release 纯库配置,无可执行 target)。 - 需要时配 `release_with_test` 跑 HashAggr JIT 单测。 -- 已知 `bolt/functions/sparksql/aggregates/CMakeLists.txt` 里 `SumAggregate.cpp` 被列两次 - (历史问题,非本次范围,未改)。 - -### Descriptor ↔ Slot 字段重复 / positional init 易错 - -**问题** -- `HashAggrJitDescriptor` 与 `HashAggrJitSlot` 字段大量重复(slot 仅多 row-layout 字段), - 各 aggregate 以 positional 方式构造 descriptor(连续 bool 靠人眼对位,易出低级 bug), - `createHashAggrJitSlot` 又逐字段 copy boilerplate。 - -**本次处理** -- 建议1:6 处 `createHashAggrJitDescriptor` 全改 C++20 designated initializer - (`.kind=`、`.inputKind=` …),消除 positional 对位风险;未 reorder 字段。 -- 建议2:`HashAggrJitSlot` 改为内嵌 `HashAggrJitDescriptor desc`,只保留 - `aggregateIndex/offset/nullByte/nullMask` + `desc`;`createHashAggrJitSlot` - 缩为 4 字段 + `.desc = descriptor`。IR 端与 `functionName()` 等约 70 处 - `slot.` 统一改成 `slot.desc.`;`offset/nullByte/nullMask` 仍在顶层。 -- 验证:无残留旧式访问、无 `descriptor.desc`、无 `.desc.offset` 误改、无双重 `.desc.desc`、 - `HashAggrJitDescriptor::signature()`(裸字段名)未受影响。未重新编译。 - -### 删除 
JIT init 对 `Aggregate::numNulls_` 的同步(commit f74cc21160) - -**背景** -- 旧机制:JIT initGroup 直接写 group 的 null bit,但不碰 `Aggregate::numNulls_`; - 而非 JIT extract 的 `isNull()` 依赖 `numNulls_` 短路(为 0 时直接判非 null)。 -- 为弥合差异,曾引入 `HashAggrJitDescriptor/Slot::initSetsNull` 标志 + - `Aggregate::addNumNulls()`,由 `GroupingSet` 在 JIT init 后手工补账。 -- 该机制最初动机:partial agg 中「add 走 JIT、extract 走非 JIT」时的 null diff。 - 现在 add/extract 均支持 JIT,价值大幅下降,且属跨层补丁、封装差。 - -**本次处理** -- 已删除:`Aggregate::addNumNulls()`、`GroupingSet` 中的 `initSetsNull → addNumNulls` - 补账循环、`HashAggrJitDescriptor/Slot::initSetsNull` 字段、各 aggregate 构造处的 - `/*initSetsNull=*/` 实参。 - -**遗留风险(需后续验证 / 补强)** -- 当前 add/extract 仍是 best-effort,存在静默回落非 JIT 的口子,最典型是 **spill**: - - extract 在 `hasSpilled()` / `supportRowBasedOutput_` 时整体跳过 JIT。 - - encoding 不符预期、distinct/mask/sortingKeys 等也会 fallback。 -- 风险场景:某 slot 用了 JIT add(init 只写 null bit、未维护 `numNulls_`),但运行时 - 回落非 JIT extract → `isNull()` 因 `numNulls_==0` 短路,把「全 null 组」误判为非 null - → 输出 0 而非 null(静默错数据)。守护用例:`hashAggrJitAllNullGroup`。 - -**后续待办(择机)** -- 重点回归:带 spill 的 partial agg(尤其全 null 组)。 -- 选一条强化方向之一: - - 做法 1(plan-time 硬门槛):只有「add + extract 全程 JIT 有保证」的 slot 才进 JIT - init/add,会 fallback 的(含可能 spill)一开始就不走 JIT。语义最干净。 - - 做法 2(fallback 现算):保留 fallback,但在非 JIT extract 入口扫一遍 null bit 重建 - `numNulls_`,spill 场景也安全,改动小。 - -## Pending - -### [P0] JIT add/merge+extract 路径正确性 bug,被 test 链接丢符号长期掩盖 - -**现象** -- 单测 `SumAggregationTest.hashAggrJitMergeAndExtract` 与 - `SumAggregationTest.hashAggrJitAllNullGroup`(均为 partial+final 两阶段、非 decimal) - 在 JIT 路径**真正执行**时结果错误: - - `hashAggrJitAllNullGroup`:group sum 期望 12,得 0。 - - `hashAggrJitMergeAndExtract`:sum/avg/min 全 null、count 全 0,相当于 add 完全没生效。 -- JIT 模块成功编译执行(无 "Symbols not found" / 无 fallback 日志),是**执行结果错**, - 不是回退。 - -**根因定位(已用 git stash 二分确认)** -- 与本轮 decimal IR 化改动**无关**:在干净 HEAD 上、仅加一个把 runtime 符号 - (如 `jit_HashAggrResizeVector`)拉进 test 可执行的 link anchor,这两个用例即 FAIL。 -- 真正背景:commit `4cbfc5e590`(runtime helper 迁出 `RowContainer.cpp` 到独立 .o)后, 
- 这些 `jit_HashAggr*` 符号**未被 test/可执行链接**(无 C++ 引用,.o 被链接器丢弃)。 - 于是 JIT 在 test 二进制里 materialize 失败 → **静默回退非 JIT** → 结果恰好正确 → - **掩盖了 JIT 路径本身的既有正确性 bug**。 -- 本轮 decimal 改动新增的 link anchor 把这些符号拉回可执行,JIT 路径终于被真正执行, - 从而**暴露**(非引入)该 bug。 - -**潜在影响(需进一步确认)** -- 若生产可执行同样没有引用这些 runtime .o,则 HashAggr JIT 在生产里可能**根本没在跑** - (一直静默回退非 JIT)。需要核实生产链接是否包含这些符号。 -- 一旦修复链接(让 JIT 真正执行),这个 add/merge+extract 正确性 bug 会立刻显现, - 必须在「启用 JIT 执行」之前先修。 - -**后续待办** -- 定位 add/merge+extract 在两阶段非 decimal 场景下结果归零/全 null 的根因 - (疑点:partial extract 与 final merge 的累加器布局 / null 语义,可能与 - commit `f74cc21160` 删除 `numNulls_` 同步相关——`allNullGroup` 正是该语义守护用例)。 -- 当前 decimal 改动保留了 link anchor(benchmark 需要它,否则 JIT 符号缺失); - 注意 anchor 会让上述 bug 在跑相关单测时显现为 FAILED。 - -**⚠️ 合入注意** -- 本轮 decimal IR 改动保留了 link anchor,启用后 JIT 路径会真正执行,导致 - `hashAggrJitMergeAndExtract` / `hashAggrJitAllNullGroup` 两个单测**变红(FAILED)**。 -- 这不是 decimal 改动引入的回归,而是上述既有 bug 被暴露;但**合入前必须先修该 P0 bug, - 否则 CI 会红**。两个选项: - 1. 先修 add/merge+extract 正确性 bug,再合入(推荐)。 - 2. 临时移除 link anchor —— 但那样 benchmark 里 JIT 符号又会解析失败、JIT 回退, - decimal 性能改善无法体现。 -- 简言之:**link anchor + 既有 bug 是绑定的**,要么一起修好,要么都先不动。 - -### [P2] chunk 同时 codegen `add_dense` 和 `add_dense_no_null`,编译时间与产物 ×2 - -**现状** -- 每个 chunk 在 `compile()` 里生成两份 add 函数,仅 `checkInputNulls` 不同: - - `bolt/jit/aggregation/HashAggrJit.cpp:1281-1282` -- 两者差异 100% 在 `genAddDenseIR` 内的 null-check 分支: - - `bolt/jit/aggregation/HashAggrJit.cpp:1016-1029`、`bolt/jit/aggregation/HashAggrJit.cpp:1040` -- 运行时按 batch 级 `inputsMayHaveNulls` 选函数指针,batch 内 stable。 - -**评估结论** -- 问题真实:codegen 时间 ~×2。 -- 但**非 P0**:编译是 per-chunk 一次性、结果缓存在 `module_`/`addDense_`/`addDenseNoNull_` - (`bolt/jit/aggregation/HashAggrJit.cpp:1301-1304`),运行热路径只调用其中一个函数, - 不存在运行期代码膨胀。影响的是编译延迟,不是执行性能。建议定级 **P2**。 - -**为什么 pending** -- 是否值得改,取决于生产实际 workload,目前未知。 - -**决策需要的数据** -- JIT 编译耗时占比 / chunk 编译次数。 -- `inputsMayHaveNulls == false` 的 batch 实际占比。 - -**候选方案** -- 维持现状:若编译耗时占比可忽略,不改。 -- 推荐(建议2,lazy):默认只编 `add_dense`,仅当出现 `inputsMayHaveNulls == false` 
- 的 batch 时再 lazy 编 `add_dense_no_null`;未就绪前 fallback 到 `add_dense` - (对 no-null 输入同样正确,仅损失少量性能)→ 砍掉常见场景一半编译量,零正确性风险。 -- 不推荐(建议1,运行期 i1 参数):会让 no-null 热路径丢失编译期 dead-branch 消除,反而变慢。 -- 高成本(建议3,alwaysinline + wrapper):理论最优但需重写 add codegen 结构, - 回归面大,仅为省一次性编译,性价比低。 diff --git a/doc/hash_aggr_jit_state_consistency_review.md b/doc/hash_aggr_jit_state_consistency_review.md deleted file mode 100644 index 873d81a32..000000000 --- a/doc/hash_aggr_jit_state_consistency_review.md +++ /dev/null @@ -1,182 +0,0 @@ -# Bolt HashAggrJit — JIT vs 非 JIT 中间状态一致性审计 - -> 审计分支:`origin/hash_aggr_jit`(原审计 HEAD: `2b6de6e186 remove useless code`) -> 比对范围:`bolt/jit/aggregation/`、`bolt/exec/GroupingSet.{cpp,h}`、`bolt/functions/{lib,sparksql,prestosql}/aggregates/` -> 审计目标:所有支持 JIT 的聚合算子,在所有支持的输入参数类型下,JIT 与非 JIT 的中间状态字节布局与运行时语义是否完全等价。 - -> **【复核更新 @ 当前工作区】** 本文档原始审计基于 `2b6de6e186`,下文已就地标注每条结论在当前工作区的状态: -> - **B3 已解决(更优方式)**:decimal/avg 的 JIT state 现为 `using` 别名直接指向非 JIT 继承的同一 POD 布局基类(`DecimalAccumulatorLayout.h` / `SumCount.h`),布局已是**同一类型**而非镜像,cross-assert 已多余。 -> - **B6 已过时**:Int128/Bool extract 已支持、`canCompileMinMaxExtract` 及整个 `CanExtractFn` 已删除。 -> - **B1/B2 已解决**:decimal sum/avg extract 已在 codegen 期按实际输出精度选择 short/long 专用 runtime helper,runtime 内不再保留 `longDecimal` 分支;`precision/scale` 与 `auxPrecision/auxScale` 已统一为 intermediate/final 语义。 -> - **B4 已解决(最小闭环)**:decimal sum/avg merge partial row 的 sum 字段读取 kind 已改为按 `precision` 推导,不再复用原始输入列 `inputKind`。 -> - **B5/B7 已确认无问题**:B5 的 NaN 排序语义已确认 JIT/非 JIT 一致;B7 的 Spark 整数 sum 溢出语义也已确认一致,均无需继续处理。 - ---- - -## 结论一句话 -累加器**字节布局**已通过单一权威 POD 布局对齐;`decimal_sum` / `decimal_avg` 的短/长 decimal extract 与 partial merge 位宽判断也已修复;B5/B7 经复核确认不是问题。当前不再有已知 JIT/非 JIT 状态不一致待修项。 - ---- - -## A. 字节级一致性表 - -| 算子 | JIT 结构 | 非 JIT 结构 | 布局 | 等价? 
| -|---|---|---|---|---| -| `avg` | `AvgAccumulatorLayout{double sum; i64 count}` `AvgOps.cpp:16` | `SumCount` `AverageAggregateBase.h:81` | 16B | ✅ **【已修】** JIT 端 `using AvgAccumulatorLayout = SumCount`,同一类型 | -| `count(*/col)` | 单 `i64` `CountOps.cpp:18` | `sizeof(i64)` `CountAggregate.cpp:49` | 8B | ✅ | -| `sum` (int/float) | `TAccumulator` `SumOps.cpp:14` | `TAccumulator` `SumAggregateBase.h:78` | 同 | ✅ | -| `min/max` (非 i128) | `T` `MinMaxOps.cpp:14` | `T` `MinMaxAggregates.cpp:93` | 同 | ✅ | -| `decimal sum` | `JitDecimalSumState{i128 sum; i64 overflow; bool isEmpty}` `HashAggrJitDecimalState.h:16` | `DecimalSum{i128 sum; i64 overflow; bool isEmpty}` `DecimalSumAggregate.h:37` | 32B | ✅ **【已修】** JIT `using` 别名指向 `DecimalSumAccumulatorLayout`,`DecimalSum` 继承之,单一权威布局 | -| `decimal avg` | `JitDecimalAvgState{i128 sum; i64 count; i64 overflow}` `HashAggrJitDecimalState.h:22` | `LongDecimalWithOverflowState` 字段同 `DecimalAggregate.h:45` | 32B | ✅ **【已修】** 同上,`using` → `LongDecimalWithOverflowLayout`,继承复用 | - ---- - -## B. 真实差异(按严重度) - -### ✅ B1. 
`decimal_sum` partial/final extract 硬编码 `int128` 输出(已解决) - -> **【复核 @ 当前工作区:已解决】** `emitDecimalSumExtract` 已在 codegen 期按实际输出 decimal 精度选择专用 helper:partial 看 `precision`,final 看 `auxPrecision`。短 decimal 调 `jit_HashAggrExtract*ShortDecimalSum` 写 `FlatVector`,长 decimal 调 `jit_HashAggrExtract*LongDecimalSum` 写 `FlatVector`;decimal sum 的 descriptor 中 `auxPrecision/auxScale` 与 `precision/scale` 镜像为同一个 sum type。 - -历史旧代码(已删除)曾在 `HashAggrDecimalRuntime.cpp` 中硬编码: -```cpp -vector->as>() // final: 直接吃 raw vector,不看 longDecimal 参数 -rowVector->childAt(0)->asFlatVector() // partial 同上 -``` - -- 调用方 `emitDecimalSumExtract` 不再传 `longDecimal`;short/long 选择已下沉为不同 runtime symbol。 -- `canCompileDecimalSumExtract` (`DecimalSumOps.cpp:118`) 无条件 `return true`,没有任何回退门。 -- Spark 注册签名 `r_precision = min(38, a_precision + 10)`(`SumAggregate.cpp:117`),factory 在 `sumType->isShortDecimal()` 时显式构造 `DecimalSumAggregate`(`SumAggregate.cpp:158`)。 -- 结果:`sum(DECIMAL(5,2))` / `sum(DECIMAL(8,3))` 的结果列就是 `FlatVector`,JIT 那里 `dynamic_cast>` 得到 `nullptr` → 空指针 set/setNull → **段错误 / 堆破坏**。 -- `GroupingSet.cpp:1304-1308` 的注释已显式承认 *"decimal avg's accumulatorKind is Int128 while its final result is a short decimal (FlatVector)"*——但那只保护 scalar output 自动按 vector 真实 type 推 kind 的路径;走 runtime helper 自己再 cast 一次的路径完全没保护到。 - -### ✅ B2. `decimal_avg` partial extract 同病(已解决) - -> **【复核 @ 当前工作区:已解决】** `emitDecimalAvgExtract` 已在 codegen 期选择 short/long 专用 runtime helper:partial 看中间 sum 精度 `precision`,final 看结果精度 `auxPrecision`。`jit_HashAggrExtractPartial*DecimalAvg` 与 final 路径一样,分别写短/长 decimal sum child。 - -历史问题:`HashAggrDecimalRuntime.cpp` 曾硬写 `asFlatVector`。Spark AVG 第二条签名 `ROW(DECIMAL(a_precision, a_scale), BIGINT)` 会沿用入参精度——短 decimal 时 partial 输出是 `int64` sum vector,旧实现会 crash。当前已通过 short/long 专用 helper 消除该风险,且 runtime 内无无效分支。 - -### ✅ B3. 
JIT/非 JIT 结构没有跨层 `static_assert`(已解决) - -> **【复核 @ `b4b99b5553`:已解决(更优方式),无需再加 static_assert】** 现已抽出零依赖 POD 布局基类(`DecimalAccumulatorLayout.h` 的 `DecimalSumAccumulatorLayout`/`LongDecimalWithOverflowLayout`、`SumCount.h` 的 `SumCount`):非 JIT 结构 `DecimalSum`/`LongDecimalWithOverflowState` **继承**之,JIT 端 `JitDecimalSumState`/`JitDecimalAvgState`/`AvgAccumulatorLayout` 用 `using` **别名同一基类**。两侧已是**同一类型**而非镜像副本,sizeof/offsetof 必然相等,cross-assert 已多余;各处保留 `static_assert(is_standard_layout_v)` 防止派生类误加数据成员破坏布局。 - -历史建议曾要求在 `HashAggrJitDecimalState.h` 加跨层 `sizeof/offsetof` 断言: -```cpp -static_assert(sizeof(JitDecimalSumState) == sizeof(sparksql::DecimalSum)); -static_assert(offsetof(JitDecimalSumState, sum) == offsetof(sparksql::DecimalSum, sum)); -// overflow / isEmpty 同理; -// JitDecimalAvgState vs LongDecimalWithOverflowState 同理; -// AvgAccumulatorLayout vs SumCount 同理。 -``` -当前已通过共享 POD 布局基类让两侧复用同一类型,跨层 `sizeof/offsetof` 断言不再是必须项。 - -### ✅ B4. row 输入 stride 仍按 plan 端 `slot.desc.inputKind`(已解决) - -> **【复核 @ 当前工作区:已解决】** 对 decimal sum/avg 的 partial merge,row field 0 是中间 sum decimal,真实位宽由中间精度 `precision` 决定,而不是原始输入列 `slot.desc.inputKind`。当前 `DecimalSumOps.cpp` / `DecimalAvgOps.cpp` 均通过 `decimalKindForPrecision(slot.desc.precision)` 读取 row field 并 cast 到 accumulator `Int128`。 - -历史问题:`DecimalSumOps.cpp` / `DecimalAvgOps.cpp` 读 row field 曾用 `slot.desc.inputKind`;runtime `fillHashAggrJitRowInputRuntime` 又按 vector 真实类型再反推一次。两侧不一致时 stride 错。当前在不扩 descriptor 的前提下,先利用已有 `precision` 作为 codegen 期 single source of truth,消除了 decimal partial row 的位宽漂移。 - -### ✅ B5. 
`MinAggregate` 初值不同(已确认无问题) - -> **【最终确认:无问题,无需继续处理】** 已确认 JIT == 非 JIT,Spark/Presto 下语义一致;补防回归测试属于可选工程加固,不再作为待办项。 -> 注意前提:Spark 与 Presto 的 min/max **共用同一份 `registerMinMax` + 同一个 `MinAggregate`/`MaxAggregate` + 同一份 JIT op**,所以确实需要验,但结论是一致的。 -> - 非 JIT 权威比较是 `SimpleVector::comparePrimitiveAsc`(`SimpleVector.h:368-380`):**NaN 视为最大**(NaN 排在所有非 NaN 之后),且该语义**不随 `SPARK_COMPATIBLE` 改变**——Spark/Presto 统一。 -> - JIT op(`MinMaxOps.cpp:45-59`)逐组合等价于 NaN=最大: -> - Min:`(oldIsNan && !valueIsNan) || (!valueIsNan && old>value)` → 避开 NaN,仅全 NaN 时结果为 NaN; -> - Max:`!oldIsNan && (valueIsNan || old - 对 {NaN,非NaN} 全部四种组合手工核对,结果与 `comparePrimitiveAsc` 完全一致。 -> - 初值 `0.0`(JIT)/ `NaN`(非 JIT Presto)都不参与比较:首条非 null 输入必定 `nullState=true` 无条件覆盖初值(`shouldStore = nullState || better`)。 -> - **结论**:B5 不是 bug,语义在 Spark 与 Presto 下均一致。无需修复。 - -- 非 JIT Presto: `kInitialValue_ = NaN` (`MinMaxAggregates.cpp:367-371`) -- JIT: 统一写 `0.0` (`MinMaxOps.cpp:21`);靠 `shouldStore = nullState || better` 让第一条非 null 输入无条件覆盖 -- 已确认所有 NaN/非 NaN 组合下 JIT 与非 JIT 一致;Spark/Presto 共用的比较语义与 JIT op 匹配。 - -### 🟢 B6. `MinMax` 混合路径 - -> **【复核 @ `b4b99b5553`:已过时】** Int128/Bool 的 extract 已实现,`canCompileMinMaxExtract` 及整个 `CanExtractFn` 已删除,extract 不再走非 JIT 混合路径。B6 描述的现象不复存在。null 槽布局一致性的 NOTE 注释建议仍可保留参考。 - -`canCompileMinMaxExtract` (`MinMaxOps.cpp:74-79`) 对 `Int128`/`Bool` 返回 `false`,extract 走非 JIT,init/update 走 JIT。当前 `slot.nullByte/nullMask` 来自 `RowContainer::nullByte/Mask`(`GroupingSet.cpp:766`),和 `exec::Aggregate::isNull` 一致——OK,但建议加一条 NOTE 注释防止后续重构改 null 槽布局踩坑。 - -### ✅ B7. 
整数 `sum` 溢出(已确认无问题) - -> **【最终确认:无问题,无需继续处理】** Spark 语义下不是 bug,JIT/非 JIT 结论一致。 -> - 走 JIT 的整数 sum **只有 Spark**:Presto 的 sum 未注册 `supportsHashAggrJit`(prestosql 下仅 Count/MinMax 接入 JIT),因此不存在"Presto sum 复用 JIT"的实际路径。注意 sum 与 min/max 不同——min/max 是 Spark/Presto 共用注册,sum 各自独立注册(Spark 有自己的 `registerSumAggregate`)。 -> - 非 JIT Spark sum:`setSumAggOverflowCheckFlag(false)`(`SumAggregate.cpp:224`)→ `Overflow=true` 分支 → 静默回绕。 -> - JIT Spark sum:整数走 `CreateAdd`(`SumOps.cpp:46-47`)→ 静默回绕。 -> - **结论**:两者完全一致(都静默回绕),Spark 下无差异;不需要加额外修复。 - -- 非 JIT 默认 `CHECK_ADD` 抛异常;Spark 在 `registerSumAggregate` 显式 `setSumAggOverflowCheckFlag(false)` → `Overflow=true` → 静默回绕(`SumAggregateBase.h:190-197`、`SumAggregate.cpp:222`)。 -- JIT 永远 `CreateAdd` 静默回绕(`SumOps.cpp:46-47`)。 -- 当前只有 Spark 注册了 `supportsHashAggrJit`,**结论一致**,无需继续处理。 - ---- - -## C. 输入类型覆盖矩阵(按算子) - -| 算子 | 非 JIT 支持 | JIT 支持 | 不一致点 | -|---|---|---|---| -| `avg` | numeric + short/long decimal | numeric only(decimal 走另一条) | OK,gate 正确(`AverageAggregate.cpp:53-54` raw decimal 显式 false) | -| `count(*)/(col)` | 全部 | numeric + short/long decimal + hugeint | OK | -| `sum` | 同 avg | 非 decimal numeric + hugeint | OK | -| `min/max` | 全部 | numeric + short/long decimal + hugeint | OK;long decimal extract 走非 JIT (B6) | -| `decimal_sum` | short/long decimal raw + `ROW` intermediate | 同左 | ✅ B1/B4 已修 | -| `decimal_avg` | short/long decimal raw + `ROW` intermediate | 同左 | ✅ B2/B4 已修 | - ---- - -## D. 最终判定(Verdict) - -> **【复核 @ 当前工作区】** 下表反映原审计;当前状态:累加器布局一致性已升级为"单一权威类型"(B3 解决);B1/B2 的 Partial/Final extract 短 decimal 崩溃已修;B4 的 decimal partial merge row-field kind 漂移已修;B5/B7 已确认无问题。当前无已知一致性待修项。 - -| 维度 | 一致? 
| -|---|---| -| Per-group 累加器结构体字节布局 | ✅ 全部一致 | -| 初值 + null bit 语义 | ✅ 一致(MinMax NaN 排序已确认一致) | -| Update 单点累加语义 | ✅ 一致(Spark 整数 sum 在 `Overflow=true` 下也一致) | -| Merge intermediate 语义 | ✅ B4 已修:decimal sum/avg partial row sum 字段按 `precision` 推导 kind | -| Partial extract → ROW 输出 | ✅ B1/B2 已修:decimal sum/avg 在 codegen 期选择 short/long 专用 helper | -| Final extract → 标量输出 | ✅ B1 已修:decimal sum final 在 codegen 期选择 short/long 专用 helper | -| 类型覆盖矩阵 | ✅ decimal 短/长结果不再依赖回退规避 crash | - ---- - -## E. 最小修复清单(按优先级) - -> **【复核 @ 当前工作区:当前优先级总览】** -> - **已解决**:B1 / B2 —— runtime helper 短 decimal 崩溃;B4 —— decimal partial row 输入 stride 漂移。 -> - **已确认无问题**:B5 / B7 —— JIT/非 JIT 语义一致,无需继续处理。 -> - **已解决/过时(无需再做)**:B3(已用单一权威布局根除)、B6(canExtract 已删、Int128/Bool extract 已支持)。 - -1. **B1/B2 修复(已完成)** - `emitDecimalSumExtract` / `emitDecimalAvgExtract` 已按 partial/final 的实际输出精度选择 short/long 专用 runtime helper,避免把 `longDecimal` 作为外部 C++ helper 参数导致 runtime 内保留无效分支。 - -2. **B3 修复(已完成)** - 已通过 `DecimalAccumulatorLayout.h` / `SumCount.h` 抽出共享 POD 布局基类,JIT 与非 JIT 复用同一权威布局类型,无需再补跨层 `sizeof/offsetof` 断言。 - -3. **B4 修复(已完成最小闭环)** - decimal partial merge 的 sum 字段已按 `precision` 推导为 `Int64/Int128`,并用该 kind 做 row-field read 和 cast。后续若要泛化到所有 ROW 字段,可再考虑 `HashAggrJitDescriptor.rowInputFields[i].kind`。 - -4. **B5(已确认无问题)** - MinMax NaN 排序已确认 JIT/非 JIT 一致,无需修复。 - -5. 
**B7(已确认无问题)** - Spark 整数 sum 溢出语义已确认 JIT/非 JIT 一致,无需修复。 - ---- - -## 附录:关键文件一览 - -| 路径 | 作用 | -|---|---| -| `bolt/jit/aggregation/HashAggrJit.{h,cpp}` | JIT 主框架、IR codegen、runtime 装载 | -| `bolt/jit/aggregation/HashAggrJitTypes.h` | `HashAggrJitDescriptor` / `HashAggrJitSlot` / 输入输出 runtime 结构体 | -| `bolt/jit/aggregation/HashAggrJitDecimalState.h` | `JitDecimalSumState` / `JitDecimalAvgState`(已通过共享 POD 布局基类消除镜像漂移) | -| `bolt/jit/aggregation/ops/{Avg,Count,Sum,MinMax,DecimalSum,DecimalAvg}Ops.cpp` | 各算子的 init/update/merge/extract 编译规则 | -| `bolt/jit/aggregation/runtime/HashAggrDecimalRuntime.cpp` | decimal sum/avg extract 的 C++ 运行时 helper(B1/B2 已修) | -| `bolt/exec/GroupingSet.cpp` | JIT chunk 调度、runtime fill、回退判断 | -| `bolt/functions/sparksql/aggregates/DecimalSumAggregate.h` | `DecimalSum` 非 JIT 结构 + `supportsHashAggrJit` | -| `bolt/functions/lib/aggregates/DecimalAggregate.h` | `LongDecimalWithOverflowState` 非 JIT 结构 | -| `bolt/functions/lib/aggregates/AverageAggregateBase.h` | `SumCount` 非 JIT 结构 | -| `bolt/functions/lib/aggregates/SumAggregateBase.h` | 整数 sum `CHECK_ADD` 与全局 `Overflow` flag | -| `bolt/functions/sparksql/aggregates/{SumAggregate,AverageAggregate}.cpp` | Spark sum/avg 注册 + JIT gate | -| `bolt/functions/prestosql/aggregates/{MinMaxAggregates,CountAggregate}.cpp` | Presto MinMax/Count 注册 + JIT gate | diff --git a/doc/hashaggr-jit-benchmark.md b/doc/hashaggr-jit-benchmark.md deleted file mode 100644 index c556114cf..000000000 --- a/doc/hashaggr-jit-benchmark.md +++ /dev/null @@ -1,1439 +0,0 @@ -# HashAggr JIT 性能评测报告 - -## 1. 
测试环境与方法 - -- **构建**:`Release` + spark 开关(`spark_compatible=True / enable_testutil=True / - skip_test=False`),对齐 `make release_spark_with_test`;benchmark 单独 - `BOLT_BUILD_BENCHMARKS=ON`,未启用 `enable_perf`(gperftools 源码下载超时,folly - benchmark 不依赖它)。 -- **benchmark**:`bolt/exec/benchmarks/HashAggrJitBenchmark.cpp`,目标 - `bolt_hashaggr_jit_benchmark`。覆盖 sum/avg/min/count(width 4/8/16/32)、 - merge(partial+final)、decimal sum/avg(当前按 `PartialFinal` 路径评测)、 - double min/max、partial extract。 -- **数据规模**:每用例 20 batch × 10000 行。 -- **关键控制**: - - JIT 模块为进程级 LRU 全局缓存,预热后**每个 JIT 函数仅编译一次**(已用 VLOG 验证 - 每个函数名 compile 次数 = 1),编译开销不计入迭代。 - - 两条路径都先 warm-up 再计时;热路径调试日志默认静默(已降级为 `VLOG(1)`)。 - - speedup = nojit / jit,**> 1 表示 JIT 更快**。 - -运行命令: - -```bash -# 低基数(聚合计算密集) -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --hashaggr_jit_benchmark_batches=20 --hashaggr_jit_benchmark_groups=100 - -# 高基数(哈希探测密集) -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --hashaggr_jit_benchmark_batches=20 --hashaggr_jit_benchmark_groups=10000 -``` - -## 2. 低基数结果(groups=100) - -| 聚合 | width4 | width8 | width16 | width32 | -|------|--------|--------|---------|---------| -| **count**(single) | **1.14x** | **1.27x** | **1.25x** | **1.33x** | -| **count**(merge) | **1.15x** | **1.23x** | **1.19x** | **1.34x** | -| sum(single) | 0.46x | 0.38x | 0.38x | 0.34x | -| sum(merge) | 0.47x | 0.40x | 0.37x | 0.37x | -| avg(single) | 0.52x | 0.46x | 0.44x | 0.45x | -| avg(merge) | 0.54x | 0.49x | 0.45x | 0.45x | -| min(single) | 0.49x | 0.42x | 0.38x | 0.39x | -| min(merge) | 0.50x | 0.43x | 0.39x | 0.40x | - -其他(width8):decimal_sum **0.40x** · decimal_avg **0.75x** · double_min 0.57x · -double_max 0.55x · partial_avg_extract 0.82x · partial_sum_extract 0.84x - -> 38 个用例中:JIT 更快 8 个(全部是 count),更慢 30 个。groups=10 对照趋势一致(误差 < 5%)。 - -## 3. 
高基数结果(groups=10000) - -| 聚合 | width4 | width8 | width16 | width32 | -|------|--------|--------|---------|---------| -| **count**(single) | 1.08x | **1.44x** | **1.59x** | **1.60x** | -| **count**(merge) | 0.94x | **1.14x** | **1.22x** | **1.29x** | -| sum(single) | 0.66x | 0.71x | 0.74x | 0.71x | -| sum(merge) | 0.68x | 0.71x | 0.72x | 0.73x | -| avg(single) | 0.80x | 0.79x | 0.84x | 0.80x | -| avg(merge) | 0.52x | 0.49x | 0.51x | 0.52x | -| min(single) | 0.62x | 0.57x | 0.62x | 0.66x | -| min(merge) | 0.68x | 0.60x | 0.64x | 0.64x | - -其他(width8):decimal_sum 0.88x · decimal_avg 0.88x · double_min 0.75x · -double_max 0.63x · partial_avg_extract 0.74x · partial_sum_extract 0.84x - -## 4. 关键发现 - -1. **只有 count 稳定正收益**:低基数 1.14–1.34x,高基数最高 1.60x,且随 fuse 宽度增大而提升。 - count 的 accumulator 最简单,融合循环省下的逐聚合函数调用/分支开销占主导。 -2. **sum/avg/min/max/decimal 在 JIT 下更慢,且基数越低越慢**:sum 从高基数 0.71x 跌到 - 低基数 0.34–0.46x。 -3. **瓶颈在 JIT add 逐行路径,而非哈希探测**:groups=10 与 groups=100 的 JIT 绝对耗时 - 几乎相同(如 width8_sum jit ≈ 5.0ms 两者一致),说明耗时与组数无关、只与行数相关—— - 即**每行 add 成本** JIT 高于非 JIT 的向量化路径。这正是“低基数本应让 JIT 更受益”的 - 预期被反转的根本原因。 -4. **decimal_avg(0.75x) 曾优于 decimal_sum(0.40x)**:这组历史数据采集时,decimal_avg - final 仍走非 JIT(Spark rescale 复杂逻辑),因此拖累相对较小。当前已补齐 final - decimal avg extract JIT helper,新的 decimal_avg 结果需以 `PartialFinal` 基准重新观察。 - -## 5. 结论与建议 - -- **现状**:HashAggr JIT 当前仅对 **count** 类(轻 accumulator、宽融合)有明确收益; - sum/avg/min/max/decimal 的 JIT 计算路径尚慢于现有向量化实现,**不建议默认开启**这些 - 聚合的 JIT。 -- **根因**:JIT add 内核的输入读取退化为**逐行外部 C 函数调用**,丧失了内联与向量化 - (详见第 6 章 perf 定位)。 -- **后续可做**: - 1. 把输入读取从 `jit_GetDecodedValue*` 外部调用改为 **JIT 内联**(直接对 flat/identity - 映射的 raw buffer 做 GEP+load),让 LLVM 能向量化取值-累加循环; - 2. 按聚合类型设白名单(先只对 count 默认启用 JIT); - 3. 对 avg-merge 的重路径专门优化。 - -## 6. 
perf 定位:sum/avg add 内核瓶颈 - -环境:`perf`(linux-tools-5.15)+ `perf_event_paranoid=1`,`-F 2999 --call-graph dwarf`, -对 `width16_sum`(fuse=16,groups=100)single-aggregation 的 JIT / 非 JIT 两条路径分别采样。 - -### 6.1 热点符号对比(self time,同一工作负载) - -| 项 | JIT 路径 | 非 JIT 路径 | -|----|----------|-------------| -| 输入取值 | `jit_GetDecodedValueI64`(外部调用,逐行)**45.4%** | `jit_GetDecodedValueI64` 仅 **0.45%** | -| 累加内核 | `[JIT]` 匿名生成码合计 **~25%** | `SumAggregateBase::addRawInput`(内联模板)**52.0%** | -| 哈希探测 | `arrayGroupProbe` 0.9% | `arrayGroupProbe` 2.2% | - -### 6.2 根因分析 - -JIT add 内核的逐行循环(`HashAggrJit.cpp:685` 的 `genAddDenseIR`)对**每行每列**都生成一次 -`CreateCall(jit_GetDecodedValueI64, {decoded, row})`(取值封装见 `loadDecodedValue` -`HashAggrJit.cpp:428`、helper 实现见 `RowContainer.cpp:1724`)。其代价: - -1. **不可内联的跨边界调用**:每行付出 call/ret + 调用约定下的寄存器溢出; -2. `DecodedVector::valueAt(index)` 内部还要判断 identity-mapping、做 indices 间接寻址; -3. **阻断向量化**:取值-累加循环因夹着 opaque 外部调用,LLVM 无法做 SIMD/循环展开。 - -而非 JIT 路径走 `SumAggregateBase::addRawInput`,整批输入在编译期类型已知、`DecodedVector` -raw buffer 被**内联顺序读取**并可向量化,因此 `jit_GetDecodedValueI64` 在该路径几乎不出现(0.45%)。 - -**count 为何不受影响**:count(`countStar` 或仅计数)不读取输入值,在 add 内核里跳过取值与 -null 检查(`HashAggrJit.cpp:697`),故没有 `jit_GetDecodedValue*` 开销,融合循环的省调用收益得以体现。 - -### 6.3 优化方向 - -最高优先级是**消除逐行外部取值调用**:对 flat / identity-mapping 的输入,在 JIT 内核里直接拿到 -`DecodedVector` 的 `data()` 基址,用 `GEP + load`(dictionary 映射则内联 indices 间接寻址)替换 -`jit_GetDecodedValue*` call,使整段取值-累加可被 LLVM 向量化——预期能把 sum/avg/min/max 的 JIT -路径从当前 0.4–0.8x 拉回到 ≥1x。 - -## 7. Direct Decoded Descriptor 优化验证 - -### 7.1 优化内容 - -本轮优化按第 6.3 节方向实现:`GroupingSet` 在每个 batch 为每个聚合输入准备一个轻量 descriptor, -JIT add 内核不再对 raw 单值输入逐行调用 `jit_GetDecodedValue*` / `jit_GetDecodedIsNull`,而是在 IR -内直接读取: - -1. `values`:decoded 后 base vector 的 raw values 基址; -2. `indices`:top-level row -> base row 的映射。flat 为 identity mapping,dictionary / constant 也由 - `DecodedVector` 统一展开为同一套映射; -3. `nulls`:top-level row null bitmap 的指针。若该指针为 nullptr,IR 直接跳过 null check; -4. 
`decodedVector`:保留原始 `DecodedVector*`,仅用于 intermediate ROW merge 的 row-field helper。 - -这样同一段 JIT IR 可以同时覆盖 flat / dictionary / constant 输入编码,不需要按 batch encoding 重新 -codegen;热循环中的普通数值读取变为 `index = indices[row]` + `values[index]`。 - -### 7.2 对比方法 - -为了衡量本次优化本身的收益,分别构建并运行了两版同一 benchmark: - -- **baseline/helper-call 版本**:原实现,每行每列调用 `jit_GetDecodedValue*` / `jit_GetDecodedIsNull`; -- **optimized/direct-descriptor 版本**:当前实现,IR 内直接读取 descriptor。 - -运行命令: - -```bash -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=5 \ - --bm_max_secs=3 \ - --bm_regex='(width8_(sum|avg|min|count|double_min|double_max)_(nojit|jit)|width8_decimal_(sum|avg)_(nojit|jit))' -``` - -测试数据规模仍为默认 20 batch × 10000 行,groups=10000,width=8。 - -### 7.3 JIT 自身优化收益 - -下表只比较两版 JIT 路径: - -| case | helper-call JIT | direct-descriptor JIT | JIT 优化收益 | -|------|----------------:|----------------------:|-------------:| -| width8_sum | 6.78ms | 3.94ms | **快 41.9%** | -| width8_avg | 6.90ms | 4.69ms | **快 32.0%** | -| width8_min | 6.71ms | 4.05ms | **快 39.6%** | -| width8_count | 3.00ms | 2.96ms | 快 1.3% | -| width8_decimal_sum | 13.27ms | 9.52ms | **快 28.3%** | -| width8_decimal_avg | 18.08ms | 14.47ms | **快 20.0%** | -| width8_double_min | 6.77ms | 4.81ms | **快 29.0%** | -| width8_double_max | 6.77ms | 4.55ms | **快 32.8%** | - -结论:消除 `jit_GetDecodedValue*` 外部 helper call 后,所有需要读取输入值的聚合都有明显收益, -幅度约 **20%–42%**。`count` 基本不读取 input value,因此收益很小,符合预期。 - -### 7.4 优化后 JIT vs no-JIT - -下表比较当前 direct-descriptor JIT 与非 JIT 路径: - -| case | no-JIT | direct-descriptor JIT | speedup = nojit / jit | -|------|-------:|----------------------:|----------------------:| -| width8_sum | 4.64ms | 3.94ms | **1.18x** | -| width8_avg | 5.29ms | 4.69ms | **1.13x** | -| width8_min | 3.75ms | 4.05ms | 0.93x | -| width8_count | 4.30ms | 2.96ms | **1.45x** | -| width8_decimal_sum | 11.96ms | 9.52ms | **1.26x** | -| width8_decimal_avg | 16.09ms | 14.47ms | **1.11x** | -| width8_double_min | 5.06ms | 
4.81ms | **1.05x** | -| width8_double_max | 4.21ms | 4.55ms | 0.93x | - -优化前,sum / avg / min / decimal / double min/max 的 JIT 路径大多慢于 no-JIT;优化后,sum、avg、 -count、decimal_sum、decimal_avg、double_min 已经变为正收益。`width8_min` 和 `width8_double_max` 仍略慢, -后续需要继续看 min/max accumulator null 判断、compare 分支以及 NaN 处理逻辑。 - -### 7.5 更新后的结论 - -- 第 6 章定位的主要瓶颈(逐行不可内联 helper call)已被验证:去掉 helper call 后,需要读输入值的 - JIT 聚合普遍获得 **20%–42%** 的 JIT 内核收益。 -- HashAggr JIT 不再只有 count 有收益;在当前 width8 / groups=10000 场景下,sum、avg、decimal、 - double_min 也已经超过 no-JIT。 -- 剩余负收益集中在 min/max 类 case,下一步优化重点应转向比较更新逻辑本身,而不是 decoded value 读取。 - -## 8. Direct Descriptor 后的最新 perf 定位 - -### 8.1 为什么仍远低于 multi_sum POC 预期 - -multi_sum POC 的核心收益假设是:把 `sum(c1)..sum(cN)` 合并后,可以显著减少重复的 group/hash lookup, -并让 `NumArgs` 在编译期已知,从而获得 loop unrolling。该假设对 POC 成立,但和当前 Bolt -HashAggregation 的生产路径并不完全等价。 - -当前 `GroupingSet::addInputForActiveRows` 中,hash/group probe 在所有 aggregate 之前统一执行一次: -`prepareForGroupProbe` / `groupProbe` 先生成 `groups = lookup_->hits.data()`,之后才进入聚合函数循环 -或 JIT chunk 执行。因此 **no-JIT 的多个 separate sums 已经共享同一次 hash lookup**,JIT 并不能像 POC -那样再节省 7 次或 15 次 hash lookup。JIT 当前主要节省的是每个 aggregate 独立 `addRawInput` 的函数 -调度、decoded 读取和多次遍历 rows 的开销。 - -另外,benchmark 使用 `AssertQueryBuilder(...).copyResults()` 测的是完整查询路径,不是纯 add kernel: -它还包含 input hash/vector encoding、RowContainer 新 group 初始化、结果 extract、结果 RowVector copy、 -task/benchmark 框架等共同开销。消除 `jit_GetDecodedValue*` 后,add kernel 已明显变快,但完整查询的 -Amdahl 上限被这些公共开销压低。 - -相关代码位置: - -- 单次 group probe:`bolt/exec/GroupingSet.cpp:344` -- probe 后统一进入 JIT chunk / aggregate function add:`bolt/exec/GroupingSet.cpp:375` -- JIT chunk 执行后 no-JIT aggregate 被跳过:`bolt/exec/GroupingSet.cpp:381` - -### 8.2 perf 方法 - -由于当前机器上硬件 counter(cycles/cache-misses/L1/dTLB 等)不可用,本轮使用 software -`cpu-clock` 采样;`/usr/bin/perf` wrapper 找不到匹配 5.4 内核的 perf binary,实际使用 -`/usr/lib/linux-tools-5.15.0-160/perf`。 - -为了减少 benchmark 初始化和 warm-up 对结果的污染,使用较大的 `--bm_min_iters`,让被测 case 的计时 -阶段占主导。代表性命令: - -```bash 
-/usr/lib/linux-tools-5.15.0-160/perf record -F 2999 \ - -o /tmp/bolt-width16-sum-jit-long.perf.data -- \ - ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=1500 --bm_max_secs=30 --bm_regex='^width16_sum_jit$' - -/usr/lib/linux-tools-5.15.0-160/perf record -F 2999 \ - -o /tmp/bolt-width16-sum-nojit-long.perf.data -- \ - ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=1500 --bm_max_secs=30 --bm_regex='^width16_sum_nojit$' -``` - -### 8.3 最新 sum benchmark 结果 - -在 direct-descriptor 优化后,sum 的完整查询收益如下: - -| case | no-JIT | JIT | speedup = nojit / jit | -|------|-------:|----:|----------------------:| -| width4_sum | 2.85ms | 2.73ms | **1.04x** | -| width8_sum | 4.68ms | 4.15ms | **1.13x** | -| width16_sum | 8.92ms | 7.48ms | **1.19x** | -| width32_sum | 17.05ms | 14.84ms | **1.15x** | - -可以看到趋势与 POC 一致:中等宽度有收益;但收益幅度只有约 4%–19%,远低于 POC 中 8/16 列 -约 42%–43% 的提升。根因是当前生产路径的 no-JIT baseline 已共享 hash probe,且完整查询包含较多 -JIT 无法消除的公共开销。 - -### 8.4 width8_sum perf 热点 - -长时间采样下的 self-time 分类如下(`cpu-clock` samples): - -| 分类 | JIT | no-JIT | 说明 | -|------|----:|-------:|------| -| add kernel | **32.72%**(JIT generated add_dense) | **45.48%**(`SumAggregateBase::addRawInput`) | JIT add 已明显少于 no-JIT add | -| hash/vector encoding | 4.15% | 4.67% | 双方共同开销 | -| hash probe | 2.27% | 2.33% | 双方都只 probe 一次,JIT 不再有 POC 中的“省多次 lookup”收益 | -| result/input copy | 3.55% | 2.74% | 完整 `copyResults` 路径成本 | -| RowContainer new/store/init | 1.52% | 2.44% | 新 group / row storage 成本 | -| JIT extract setter | **4.35%** | 0.10% | JIT extract 仍调用 `jit_HashAggrSetFlatI64` helper | -| dynamic_cast/type dispatch | **6.62%** | 0.21% | JIT 路径额外的结果/类型处理开销 | - -top symbols 中 JIT 路径最大热点已经从 `jit_GetDecodedValueI64` 迁移到 `[JIT]` 生成码本身; -`jit_GetDecodedValue*` 不再是主热点,说明第 7 章的 direct descriptor 已生效。 - -### 8.5 width16_sum perf 热点 - -width16 下趋势更明显: - -| 分类 | JIT | no-JIT | 说明 | -|------|----:|-------:|------| -| add kernel | **35.51%**(JIT generated 
add_dense) | **53.87%**(`SumAggregateBase::addRawInput`) | JIT kernel 节省明显 | -| aggregate init | 0.02% | 4.99% | JIT fused init 基本消除了 per-aggregate init 热点 | -| hash/vector encoding | 1.69% | 2.46% | 共同开销 | -| hash probe | 1.90% | 1.28% | 共同开销;采样误差下同量级 | -| result/input copy | 4.24% | 3.82% | 完整查询共同成本 | -| JIT extract setter | **6.37%** | 0.08% | JIT extract helper 成为新热点之一 | -| dynamic_cast/type dispatch | **7.02%** | 0.17% | JIT 路径额外成本,抵消一部分 add kernel 收益 | - -从绝对耗时估算,JIT add kernel 已从 no-JIT 的约 4.8ms(`8.92ms * 53.87%`)降到约 2.7ms -(`7.48ms * 35.51%`)。也就是说 add kernel 本身接近 **1.8x**,但完整查询最终只有 **1.19x**, -因为剩余时间被 probe、encoding、output materialization、copy 和 JIT extract helper/type dispatch 稀释。 - -### 8.6 最新瓶颈排序 - -1. **JIT 已无法再通过“少做 hash lookup”获得 POC 级收益**:Bolt baseline 本身已经 one probe for all - aggregates,这是和 multi_sum POC 最大的结构差异。 -2. **JIT add_dense 生成码仍是最大热点**:direct descriptor 去掉 helper call 后,热点回到真正的 - scalar RMW 聚合内核。当前每个 slot 每行仍要做 `indices[row]`、`values[index]`、accumulator null bit - clear、old accumulator load、add、store;这些操作围绕 `groups[row]` 间接指针,LLVM 很难 SIMD 化。 -3. **JIT extract 仍有 helper/type-dispatch 开销**:`jit_HashAggrSetFlatI64` 和 `__dynamic_cast` 在 JIT - 路径合计约 10%–13%,这是 direct descriptor 后的新显性瓶颈。no-JIT extract 使用 aggregate 自身的 - typed extract,开销低得多。 -4. **完整 query benchmark 的公共成本很高**:hash/vector encoding、RowContainer、result copy、task 框架等 - 不随 add kernel 优化而下降,限制最终端到端 speedup。 - -### 8.7 后续调优建议 - -1. **优化 JIT extract**:像 add path 一样为 extract 也传入 output descriptor(raw values/nulls),在 IR - 里直接写 FlatVector buffer,替换 `jit_HashAggrSetFlatI64` helper,并尽量避免 `dynamic_cast`。 -2. **增加 flat/no-null 快路径**:当前为了同时支持 flat/dictionary/constant,所有输入都走 `indices[row]`。 - 可以保持同一份 IR 兼容多 encoding,但在 loop preheader 根据 descriptor 判断 `indices` 是否 identity, - 分支到 flat 直读 `values[row]` 的 loop;dictionary/constant 再走 mapped loop。这样不需要按 batch - encoding 重新 codegen,但可以让常见 flat case 少一次 indices load。 -3. 
**消除 no-null 场景下的 per-row accumulator null clear**:sum 当前每行每 slot 都执行 - `clearAccumulatorNull`。对于 input 确认无 null 的 batch,可以考虑 batch-level 或 new-group-level 地清 - accumulator null,避免每行重复写 null bitmap。 -4. **区分纯 add kernel benchmark 与完整 query benchmark**:POC 结论更接近 add kernel 层收益;生产端到端 - 收益需要单独扣除 hash probe/output/copy 等公共成本。后续 benchmark 可以补一个只测 `GroupingSet` add - 的 microbenchmark,避免 `copyResults` 稀释定位。 -5. **继续限制 fuse width 的甜点区间**:当前 width16/32 仍有收益,但并未出现 POC 的巨大收益。考虑先保持 - `maxFuseWidth=16` 或最多 32;更宽时需要结合 cache/TLB 数据重新评估。 - -## 9. JIT extract raw output descriptor 优化验证 - -### 9.1 优化内容 - -本轮继续优化第 8.6 节定位出的 extract 瓶颈:JIT extract 不再对普通 FLAT primitive 输出逐行调用 -`jit_HashAggrSetFlat*` helper,而是由 `GroupingSet` 为每个 aggregate output 准备 -`HashAggrJitOutput` descriptor: - -1. `values`:`FlatVector::mutableRawValues()`; -2. `nulls`:`BaseVector::mutableRawNulls()`; -3. `vector`:原始 `BaseVector*`,保留给 decimal / partial avg ROW 等复杂输出 helper fallback。 - -JIT extract IR 对 `Int8/Int16/Int32/Int64/Float/Double` 直接执行: - -```text -values[row] = value -isNull ? 
clear null bitmap bit : set null bitmap bit -``` - -`Bool`、`Int128/decimal`、partial avg ROW output 暂不做 raw 写,仍通过 descriptor 中的 `vector` 走原 helper。 - -### 9.2 功能与性能验证命令 - -构建: - -```bash -cmake --build --preset conan-release --target bolt_hashaggr_jit_benchmark --parallel 2 -``` - -功能覆盖(sum/avg/min/count/decimal/double min-max): - -```bash -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=3 --bm_max_secs=2 \ - --bm_regex='^width8_(sum|avg|min|count|double_min|double_max|decimal_sum|decimal_avg)_(nojit|jit)$' -``` - -sum 宽度扫描: - -```bash -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=20 --bm_max_secs=5 \ - --bm_regex='^width(4|8|16|32)_sum_(nojit|jit)$' -``` - - -### 9.3 最新 width8 结果 - -| case | no-JIT | JIT | speedup = nojit / jit | -|------|-------:|----:|----------------------:| -| width8_sum | 4.61ms | 3.55ms | **1.30x** | -| width8_avg | 5.41ms | 4.18ms | **1.29x** | -| width8_min | 3.71ms | 3.48ms | **1.07x** | -| width8_count | 4.30ms | 2.41ms | **1.78x** | -| width8_decimal_sum | 12.13ms | 9.58ms | **1.27x** | -| width8_decimal_avg | 16.49ms | 14.77ms | **1.12x** | -| width8_double_min | 4.94ms | 4.23ms | **1.17x** | -| width8_double_max | 4.21ms | 3.81ms | **1.10x** | - -对比第 7.4 节,`min` / `double_max` 已从略慢于 no-JIT 变为正收益;`sum`、`avg`、`count` 也继续提升。 - -### 9.4 最新 sum 宽度扫描 - -| case | no-JIT | JIT | speedup = nojit / jit | -|------|-------:|----:|----------------------:| -| width4_sum | 2.60ms | 2.44ms | **1.07x** | -| width8_sum | 4.65ms | 3.45ms | **1.35x** | -| width16_sum | 9.06ms | 6.06ms | **1.50x** | -| width32_sum | 17.28ms | 12.21ms | **1.42x** | - -相比第 8.3 节(extract 优化前 width16_sum 约 1.19x、width32_sum 约 1.15x),raw output descriptor 后 -宽聚合收益明显扩大,说明之前 extract helper/type-dispatch 确实抵消了大量 add_dense 的融合收益。 - - -### 9.5 详细数据 - -``` -$ ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark -============================================================================ 
-[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s -============================================================================ -width4_sum_nojit 2.57ms 388.70 -width4_sum_jit 2.40ms 416.48 ----------------------------------------------------------------------------- -width4_avg_nojit 3.27ms 306.05 -width4_avg_jit 2.59ms 385.86 ----------------------------------------------------------------------------- -width4_min_nojit 2.40ms 417.30 -width4_min_jit 2.42ms 413.22 ----------------------------------------------------------------------------- -width4_count_nojit 2.42ms 413.03 -width4_count_jit 1.90ms 525.58 ----------------------------------------------------------------------------- -width4_merge_sum_nojit 3.60ms 277.57 -width4_merge_sum_jit 3.29ms 303.98 ----------------------------------------------------------------------------- -width4_merge_avg_nojit 4.58ms 218.27 -width4_merge_avg_jit 7.19ms 138.99 ----------------------------------------------------------------------------- -width4_merge_min_nojit 3.42ms 292.16 -width4_merge_min_jit 3.33ms 300.62 ----------------------------------------------------------------------------- -width4_merge_count_nojit 3.37ms 296.86 -width4_merge_count_jit 2.88ms 346.72 ----------------------------------------------------------------------------- -width8_sum_nojit 4.62ms 216.54 -width8_sum_jit 3.36ms 297.77 ----------------------------------------------------------------------------- -width8_avg_nojit 5.37ms 186.22 -width8_avg_jit 4.14ms 241.70 ----------------------------------------------------------------------------- -width8_min_nojit 3.70ms 270.01 -width8_min_jit 3.50ms 285.84 ----------------------------------------------------------------------------- -width8_count_nojit 4.26ms 235.01 -width8_count_jit 2.35ms 425.31 ----------------------------------------------------------------------------- -width8_merge_sum_nojit 6.18ms 161.79 -width8_merge_sum_jit 4.60ms 217.60 
----------------------------------------------------------------------------- -width8_merge_avg_nojit 7.70ms 129.85 -width8_merge_avg_jit 12.31ms 81.22 ----------------------------------------------------------------------------- -width8_merge_min_nojit 5.31ms 188.20 -width8_merge_min_jit 4.83ms 206.90 ----------------------------------------------------------------------------- -width8_merge_count_nojit 5.72ms 174.70 -width8_merge_count_jit 3.58ms 279.62 ----------------------------------------------------------------------------- -width16_sum_nojit 9.01ms 110.95 -width16_sum_jit 5.93ms 168.53 ----------------------------------------------------------------------------- -width16_avg_nojit 10.53ms 94.95 -width16_avg_jit 7.38ms 135.53 ----------------------------------------------------------------------------- -width16_min_nojit 7.92ms 126.22 -width16_min_jit 6.24ms 160.20 ----------------------------------------------------------------------------- -width16_count_nojit 7.73ms 129.35 -width16_count_jit 3.50ms 285.49 ----------------------------------------------------------------------------- -width16_merge_sum_nojit 11.44ms 87.44 -width16_merge_sum_jit 7.58ms 131.87 ----------------------------------------------------------------------------- -width16_merge_avg_nojit 15.68ms 63.79 -width16_merge_avg_jit 23.72ms 42.16 ----------------------------------------------------------------------------- -width16_merge_min_nojit 10.21ms 97.95 -width16_merge_min_jit 7.94ms 125.98 ----------------------------------------------------------------------------- -width16_merge_count_nojit 10.10ms 98.97 -width16_merge_count_jit 5.22ms 191.41 ----------------------------------------------------------------------------- -width32_sum_nojit 17.20ms 58.13 -width32_sum_jit 12.08ms 82.76 ----------------------------------------------------------------------------- -width32_avg_nojit 19.42ms 51.48 -width32_avg_jit 15.11ms 66.20 
----------------------------------------------------------------------------- -width32_min_nojit 15.56ms 64.26 -width32_min_jit 12.53ms 79.78 ----------------------------------------------------------------------------- -width32_count_nojit 15.66ms 63.85 -width32_count_jit 7.12ms 140.37 ----------------------------------------------------------------------------- -width32_merge_sum_nojit 23.30ms 42.91 -width32_merge_sum_jit 16.24ms 61.59 ----------------------------------------------------------------------------- -width32_merge_avg_nojit 30.22ms 33.09 -width32_merge_avg_jit 47.82ms 20.91 ----------------------------------------------------------------------------- -width32_merge_min_nojit 19.79ms 50.52 -width32_merge_min_jit 15.78ms 63.37 ----------------------------------------------------------------------------- -width32_merge_count_nojit 19.32ms 51.75 -width32_merge_count_jit 10.17ms 98.30 ----------------------------------------------------------------------------- -width8_decimal_sum_nojit 12.03ms 83.13 -width8_decimal_sum_jit 9.83ms 101.71 ----------------------------------------------------------------------------- -width8_decimal_avg_nojit 16.29ms 61.38 -width8_decimal_avg_jit 14.77ms 67.70 ----------------------------------------------------------------------------- -width8_double_min_nojit 5.05ms 197.94 -width8_double_min_jit 4.16ms 240.16 ----------------------------------------------------------------------------- -width8_double_max_nojit 4.18ms 239.18 -width8_double_max_jit 3.84ms 260.33 ----------------------------------------------------------------------------- -width8_high_card_partial_avg_extract_nojit 61.78ms 16.19 -width8_high_card_partial_avg_extract_jit 80.29ms 12.46 ----------------------------------------------------------------------------- -width8_high_card_partial_sum_extract_nojit 27.36ms 36.54 -width8_high_card_partial_sum_extract_jit 23.51ms 42.54 ----------------------------------------------------------------------------- -``` - - 
- -### 9.6 当前剩余瓶颈分析 - -从完整 benchmark 结果看,direct decoded input descriptor 和 raw output descriptor 已经解决了此前最明显的 -两类 helper 开销:`jit_GetDecodedValue*` 输入读取 helper,以及 `jit_HashAggrSetFlat*` / `dynamic_cast` -输出写 helper。普通 FLAT primitive 聚合现在基本都已经转为正收益,但仍有几类结构性瓶颈。 - -#### 9.6.1 最大负收益:merge avg 的 ROW intermediate 路径 - -目前最明显的回退集中在 `merge_avg_jit`: - -| case | no-JIT | JIT | speedup = nojit / jit | -|------|-------:|----:|----------------------:| -| width4_merge_avg | 4.58ms | 7.19ms | **0.64x** | -| width8_merge_avg | 7.70ms | 12.31ms | **0.63x** | -| width16_merge_avg | 15.68ms | 23.72ms | **0.66x** | -| width32_merge_avg | 30.22ms | 47.82ms | **0.63x** | - -这个比例在不同 width 下非常稳定,说明不是 benchmark 噪声,而是路径本身还没有被优化。根因是 avg merge -的 intermediate input 是 `ROW(sum, count)`,没有完全吃到 raw decoded descriptor 优化:普通数值输入已经能 -通过 `values + indices + nulls` 在 JIT IR 中直接 load,但 ROW field 读取仍然需要类似 -`jit_GetDecodedRowFieldDouble` / `jit_GetDecodedRowFieldI64` / `jit_GetDecodedRowFieldIsNull` 的 helper 或 -DecodedVector row-field 路径。 - -partial avg extract 也印证了这个结论: - -| case | no-JIT | JIT | speedup = nojit / jit | -|------|-------:|----:|----------------------:| -| width8_high_card_partial_avg_extract | 61.78ms | 80.29ms | **0.77x** | -| width8_high_card_partial_sum_extract | 27.36ms | 23.51ms | **1.16x** | - -partial sum 输出是 FLAT,已经受益于 raw output descriptor;partial avg 输出是 ROW,目前仍走 helper fallback, -因此仍然更慢。 - -**建议**:短期可考虑禁用 `merge_avg_jit` 和 `partial_avg_extract_jit`;长期需要为 ROW input/output 增加 -descriptor,把 `sum` / `count` 两个 child vector 的 raw values/nulls 直接传给 JIT。 - -#### 9.6.2 主路径瓶颈:JIT add_dense 仍是 row-based scalar RMW loop - -普通 sum/avg/count 已经有明显正收益: - -| 聚合 | width4 | width8 | width16 | width32 | -|------|-------:|-------:|--------:|--------:| -| sum | 1.07x | 1.38x | 1.52x | 1.42x | -| avg | 1.26x | 1.30x | 1.43x | 1.29x | -| count | 1.27x | 1.81x | 2.21x | 2.20x | - -count 收益显著高于 sum/avg,因为 count 不需要读取 input value,也不需要做加法以外的复杂状态维护。sum/avg -的剩余成本主要回到 JIT 生成码本身: - -```text -group = 
groups[row] -index = indices[row] -value = values[index] -load accumulator -clear accumulator null bit -add / update count -store accumulator -``` - -这仍然是 row-based scalar read-modify-write loop:`groups[row]` 是间接指针访问,`indices[row]` 即使在 flat -input 下也要额外 load,accumulator 存在 RowContainer row storage 中而不是连续 columnar buffer,LLVM 很难 -做 SIMD 化。 - -**建议**:优先做 flat/no-null add_dense 快路径。在 input 是 flat identity mapping 且没有 null 时,直接生成 -`value = values[row]`,跳过 `indices[row]` 和 input null 分支。 - -#### 9.6.3 小 width 收益有限:公共固定成本占比高 - -width4 下收益明显弱于 width8/16/32: - -| case | speedup | -|------|--------:| -| width4_sum | 1.07x | -| width4_min | 0.99x | -| width4_count | 1.27x | - -width4 中可 fusion 的 aggregate 数量少,JIT 能省下的 per-aggregate dispatch / loop traversal 不多,但 descriptor -准备、JIT chunk 调用、result vector resize、output materialization、RowContainer / hash probe / copyResults 等 -完整 query 公共成本仍然存在。 - -**建议**:默认启用策略上应更偏向 width8+ 或 count/sum 这类收益稳定的 case;低 width case 需要结合实际 -query 成本谨慎启用。 - -#### 9.6.4 min/max 收益较小:compare 与 null-init 分支仍偏重 - -min/max 已经转为正收益,但弱于 sum/count: - -| case | speedup | -|------|--------:| -| width8_min | 1.06x | -| width16_min | 1.27x | -| width32_min | 1.24x | -| width8_double_min | 1.21x | -| width8_double_max | 1.09x | - -min/max 每行更新不仅要读取 input value,还要处理 accumulator 是否 null、首次 non-null 初始化、compare 分支; -double min/max 还可能受 NaN / ordering 语义影响。相比 sum 的简单加法,这些分支更难被 LLVM 优化。 - -**建议**:后续可为 no-null + accumulator initialized 场景生成更简单的 compare-only 快路径。 - -#### 9.6.5 decimal 仍受复杂 overflow/precision 逻辑限制 - -decimal 现在已经是正收益,但幅度有限: - -| case | no-JIT | JIT | speedup = nojit / jit | -|------|-------:|----:|----------------------:| -| width8_decimal_sum | 12.03ms | 9.83ms | **1.22x** | -| width8_decimal_avg | 16.29ms | 14.77ms | **1.10x** | - -decimal update/extract 仍包含 int128 accumulator、overflow state、precision/scale 检查、final extract overflow -处理以及 decimal avg rescale 等复杂逻辑,无法像 primitive sum 一样完全变成简单 raw load/store。 - -**建议**:decimal 可以继续专项优化,但优先级低于 ROW avg 
路径和 flat/no-null add_dense 快路径。 - -#### 9.6.6 当前瓶颈优先级 - -1. **P0:ROW avg 路径**:`merge_avg_jit` 和 `partial_avg_extract_jit` 是目前唯一大幅负收益路径。短期禁用, - 长期做 ROW input/output descriptor。 -2. **P1:flat/no-null add_dense 快路径**:减少 `indices[row]` 间接读取和 null 分支,继续提升 sum/avg/min 主路径。 -3. **P2:减少 per-row accumulator null clear**:对于 no-null input 或 accumulator 已初始化场景,把 null clear 从 - per-row 下沉到更粗粒度。 -4. **P3:min/max compare 快路径**:减少 accumulator null/init 分支。 -5. **P4:decimal 专项优化**:拆解 overflow/precision helper,但收益优先级相对靠后。 - -### 9.7 perf 验证 - -代表性命令: - -```bash -/usr/lib/linux-tools-5.15.0-160/perf record -F 2999 \ - -o /tmp/bolt-width16-sum-jit-outputdesc.perf.data -- \ - ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=200 --bm_max_secs=8 --bm_regex='^width16_sum_jit$' - -/usr/lib/linux-tools-5.15.0-160/perf report \ - -i /tmp/bolt-width16-sum-jit-outputdesc.perf.data \ - --stdio --no-children --sort symbol --percent-limit 0 \ - | grep -E 'jit_HashAggrSetFlatI64|dynamic_cast|__dynamic|__do_dyncast|HashAggrSetFlat' -``` - -结果:`jit_HashAggrSetFlatI64` / `HashAggrSetFlat*` 不再出现在 perf report 中;`__dynamic_cast` 降到 -约 **0.27%**,`__do_dyncast` 合计约 **0.15%**。对比第 8.5 节,extract helper 与 dynamic_cast/type -dispatch 从 JIT 路径约 **13%** 的显性热点降为噪声级别。 - -### 9.8 更新后的结论 - -- direct decoded input descriptor 解决了 add_dense 的外部取值 helper;raw output descriptor 继续解决了 - extract 的 per-row setter helper / dynamic_cast。 -- 当前 width8 常见 FLAT primitive 数值聚合已全部为正收益;sum 宽度扫描在 width16 达到约 **1.52x**,更接近 - 最初 multi_sum POC 的方向性预期。 -- 截至 raw output descriptor 阶段,最大遗留问题是 ROW intermediate/output:`merge_avg_jit` 和 - `partial_avg_extract_jit` 仍显著慢于 no-JIT;第 10 章继续更新了 ROW descriptor 优化后的最新结果。 -- 主路径剩余瓶颈回到真正的 JIT add_dense 生成码、hash/vector encoding、RowContainer 和 result copy 等公共成本; - 后续若继续优化,优先考虑 flat/no-null add_dense 快路径、减少 per-row accumulator null clear,以及拆分纯 - `GroupingSet` add microbenchmark 来单独观察 kernel 收益。 - -## 10. 
ROW avg input/output descriptor 优化验证 - -### 10.1 优化内容 - -针对第 9.6.1 节的 P0 瓶颈,本轮为 avg 的 `ROW(sum, count)` intermediate input/output 增加了 raw descriptor: - -1. `HashAggrJitDecodedInput` 增加 `rowField0Values/nulls`、`rowField1Values/nulls`,用于 avg merge 直接读取 - partial 输出的 `sum` / `count` child FlatVector; -2. `HashAggrJitOutput` 增加同名 row field 指针,用于 partial avg extract 直接写 `sum` / `count` child FlatVector; -3. `loadDecodedRowField` / `isDecodedRowFieldNull` 对 field 0/1 走 `GEP + load` / raw null bitmap,避免逐行 - `DecodedVector` ROW field helper; -4. `emitPartialAvgResult` 在存在 row field raw output 时直接写 child values 和 row null bitmap,保留 helper fallback - 以覆盖非预期编码。 - -这个优化利用了 partial avg 的数据流约束:`addIntermediateResults` 的输入来自 `extractAccumulator`,而 -`extractAccumulator` 输出的 ROW child 均为 FLAT,因此 avg merge 拆 ROW 时只需要支持 child FlatVector 快路径。 - -关键代码位置: - -- ROW input/output descriptor 字段:`bolt/jit/aggregation/HashAggrJit.h:61` -- ROW field raw load/null check:`bolt/jit/aggregation/HashAggrJit.cpp:698` / `bolt/jit/aggregation/HashAggrJit.cpp:723` -- partial avg ROW raw output 写入:`bolt/jit/aggregation/HashAggrJit.cpp:795` -- avg merge ROW child raw input 填充:`bolt/exec/GroupingSet.cpp:120` -- partial avg ROW child raw output 填充:`bolt/exec/GroupingSet.cpp:148` -- JIT extract output descriptor 准备:`bolt/exec/GroupingSet.cpp:1194` - -### 10.2 验证命令 - -构建: - -```bash -cmake --build --preset conan-release --target bolt_hashaggr_jit_benchmark --parallel 1 -``` - -完整 benchmark: - -```bash -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_min_iters=5 -``` - -P0 专项复测: - -```bash -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_regex='width(4|8|16|32)_merge_avg' --bm_min_iters=20 - -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_regex='width8_high_card_partial_avg_extract' --bm_min_iters=50 --bm_max_secs=10 -``` - -### 10.3 merge_avg 修复结果 - -ROW input descriptor 后,`merge_avg_jit` 从此前稳定 0.63–0.66x 的最大负收益路径,变为稳定正收益: - -| 
case | 修复前 speedup | 修复后 no-JIT | 修复后 JIT | 修复后 speedup | -|------|---------------:|--------------:|-----------:|---------------:| -| width4_merge_avg | 0.64x | 4.63ms | 4.02ms | **1.15x** | -| width8_merge_avg | 0.63x | 7.86ms | 5.88ms | **1.34x** | -| width16_merge_avg | 0.66x | 14.83ms | 10.56ms | **1.40x** | -| width32_merge_avg | 0.63x | 31.13ms | 24.21ms | **1.29x** | - -完整 benchmark 的同类结果也保持正收益:width4/8/16/32 分别约 **1.23x / 1.30x / 1.41x / 1.37x**。 - -### 10.4 partial_avg_extract 结果 - -partial avg extract 的 ROW output helper 已被 raw child 写入替换,较第 9.5 节中 80.29ms 的 JIT 路径有明显改善; -但在当前完整查询 benchmark 中,端到端仍有波动且长跑仍略慢于 no-JIT: - -| case | 第 9.5 节 JIT | 修复后 no-JIT | 修复后 JIT | 修复后 speedup | -|------|--------------:|--------------:|-----------:|---------------:| -| width8_high_card_partial_avg_extract | 80.29ms | 60.17ms | 68.19ms | 0.88x | - -对照同一轮 partial sum extract: - -| case | no-JIT | JIT | speedup | -|------|-------:|----:|--------:| -| width8_high_card_partial_sum_extract | 27.50ms | 22.50ms | **1.22x** | - -因此,本轮对 partial avg extract 的结论是:ROW output helper 瓶颈已被削弱,但该 case 的端到端性能还没有稳定转正。 -剩余成本大概率不再只是 ROW child 写出,而是 high-cardinality 场景下每行新 group 初始化、avg accumulator -`sum+count` 更新、RowVector 输出物化以及完整 `copyResults` 公共成本共同导致。 - -### 10.5 完整 benchmark 快照 - -完整 benchmark 命令: - -```bash -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_min_iters=5 -``` - -本轮完整结果的 speedup 汇总如下(speedup = no-JIT / JIT,**> 1 表示 JIT 更快**): - -| 聚合 | width4 | width8 | width16 | width32 | -|------|-------:|-------:|--------:|--------:| -| sum(single) | **1.04x** | **1.34x** | **1.48x** | **1.47x** | -| avg(single) | **1.22x** | **1.27x** | **1.45x** | **1.30x** | -| min(single) | **1.01x** | **1.07x** | **1.27x** | **1.27x** | -| count(single) | **1.28x** | **1.77x** | **2.20x** | **2.28x** | -| sum(merge) | **1.08x** | **1.34x** | **1.49x** | **1.43x** | -| avg(merge) | **1.23x** | **1.30x** | **1.41x** | **1.37x** | -| min(merge) | **1.03x** | **1.11x** | **1.28x** | 
**1.33x** | -| count(merge) | **1.17x** | **1.57x** | **1.86x** | **1.97x** | - -其他 width8 用例: - -| case | no-JIT | JIT | speedup | -|------|-------:|----:|--------:| -| width8_decimal_sum | 12.05ms | 9.60ms | **1.26x** | -| width8_decimal_avg | 16.15ms | 14.60ms | **1.11x** | -| width8_double_min | 4.95ms | 4.15ms | **1.19x** | -| width8_double_max | 4.13ms | 3.86ms | **1.07x** | -| width8_high_card_partial_avg_extract | 58.15ms | 63.36ms | 0.92x | -| width8_high_card_partial_sum_extract | 24.48ms | 21.34ms | **1.15x** | - -结论:ROW avg input/output descriptor 之后,完整 benchmark 中除 `partial_avg_extract` 外,当前覆盖的主要 -single / merge primitive 聚合均已转为正收益;`count` 和宽 `sum/avg/merge_avg` 收益最稳定。 - -### 10.6 更新后的瓶颈优先级 - -1. **P0 已基本解决:merge_avg ROW input**。`merge_avg_jit` 已从 0.63–0.66x 拉升到 1.15–1.40x,是本轮最主要收益。 -2. **P1:partial_avg_extract 仍需继续拆解**。ROW output raw descriptor 已降低 JIT 绝对耗时,但端到端仍约 0.88x; - 下一步需要用 perf 区分 add/update、新 group 初始化、ROW output materialization 和 `copyResults` 的占比。 -3. **P2:flat/no-null add_dense 快路径**。普通 sum/avg/min 主路径仍有 `indices[row]` 和 per-row null 处理成本。 -4. **P3:减少 per-row accumulator null clear**。对于 no-null input 或 accumulator 已初始化场景,把 null clear 从 per-row - 下沉到更粗粒度。 -5. **P4:min/max compare 和 decimal 专项优化**。收益优先级低于 partial avg extract 和 add_dense 主路径。 - -### 10.7 本轮结论 - -- avg merge 的 ROW intermediate input 已吃到 raw descriptor 优化,最大负收益 case 已转正。 -- partial avg extract 的 ROW output helper 已优化,但 benchmark 仍显示端到端略慢,需要继续 perf 定位剩余成本。 -- HashAggr JIT 当前更适合 sum/count/avg merge 这类宽融合场景;partial avg extract 暂不应作为默认开启 JIT 的依据。 - -## 11. 
P1:partial_avg_extract 火焰图定位 - -### 11.1 perf 采集与火焰图生成方法 - -对 `width8_high_card_partial_avg_extract` 的 JIT / no-JIT 两条路径分别采样,并生成火焰图。 - -采样(`-F 999 --call-graph dwarf`): - -```bash -# JIT 路径 -/usr/lib/linux-tools-5.15.0-160/perf record -F 999 --call-graph dwarf \ - -o /tmp/bolt-partial-avg-extract-jit.perf.data -- \ - ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=200 --bm_max_secs=20 \ - --bm_regex='^width8_high_card_partial_avg_extract_jit$' - -# no-JIT 路径 -/usr/lib/linux-tools-5.15.0-160/perf record -F 999 --call-graph dwarf \ - -o /tmp/bolt-partial-avg-extract-nojit.perf.data -- \ - ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark \ - --bm_min_iters=200 --bm_max_secs=20 \ - --bm_regex='^width8_high_card_partial_avg_extract_nojit$' -``` - -折叠栈并用 FlameGraph 生成 SVG: - -```bash -FG=/data00/home/liyang.127/FlameGraph - -/usr/lib/linux-tools-5.15.0-160/perf script -i /tmp/bolt-partial-avg-extract-jit.perf.data \ - | $FG/stackcollapse-perf.pl > /tmp/partial-avg-extract-jit.folded -$FG/flamegraph.pl --title "width8_high_card_partial_avg_extract JIT" \ - /tmp/partial-avg-extract-jit.folded \ - > doc/hashaggr-jit-partial-avg-extract-jit-flamegraph.svg - -/usr/lib/linux-tools-5.15.0-160/perf script -i /tmp/bolt-partial-avg-extract-nojit.perf.data \ - | $FG/stackcollapse-perf.pl > /tmp/partial-avg-extract-nojit.folded -$FG/flamegraph.pl --title "width8_high_card_partial_avg_extract no-JIT" \ - /tmp/partial-avg-extract-nojit.folded \ - > doc/hashaggr-jit-partial-avg-extract-nojit-flamegraph.svg -``` - -> 选用 `--call-graph dwarf` 而非 fp/lbr:Release 二进制开启 `-fomit-frame-pointer`,帧指针回溯会断栈; -> 该机器硬件 PMU/LBR 也不可用(只能用 software `cpu-clock`),dwarf 基于 `.eh_frame`/CFI 回溯, -> 在优化过且含 JIT 匿名段的二进制上能还原完整调用栈,适合做火焰图,代价是数据量大、采样频率需调低到 999。 - -产物火焰图: - -- `doc/hashaggr-jit-partial-avg-extract-jit-flamegraph.svg` -- `doc/hashaggr-jit-partial-avg-extract-nojit-flamegraph.svg` - -### 11.2 采样结构与噪声剥离 - -本次火焰图有两个需要先剥离的结构性噪声,否则会误判热点: - -1. 
**JIT 后台编译线程**:JIT 编译发生在 `CPUThreadPool0` 上的 `llvm::orc::*` / `PassManager` / - `SelectionDAG` 调用链,在原始火焰图里占比很大,但它属于一次性 plan 编译开销(LRU 缓存命中后不再编译), - 不计入热路径。剥离方式是过滤掉 `llvm::orc` / `*PassManager` / `SelectionDAG` / `MachineFunction` 等编译栈。 -2. **benchmark 主线程**:`bolt_hashaggr_j` 主线程几乎全是 plan 解析(`Parser::parse`、`parseTypeSignature`) - 和动态链接 setup(`elf_dynamic_do_Rela`、`do_lookup_x`),是 query 构建噪声,真正的算子执行在 - `CPUThreadPool0` 执行线程上。 - -剥离后,对执行线程上的真实算子热点做 leaf 归类对比(self time,已排除编译栈)。 - -### 11.3 执行线程热点对比 - -| leaf 热点 | JIT | no-JIT | 说明 | -|----------|----:|-------:|------| -| `[perf-*.map]`(JIT 生成码) | 9.7% | 9.7%(no-JIT 是其它匿名段) | JIT add/extract 生成码 | -| `clear_page_erms` | 8.9% | 9.7% | 内核清零新申请页 | -| `arrayGroupProbe` | 4.8% | 1.6% | hash 探测 | -| `AverageAggregateBase::addRawInput` | 3.2% | 5.6% | avg accumulator 更新 | -| `SumAggregateBase::addRawInput` | 3.2% | 2.4% | 子聚合更新 | -| `MinAggregate::addRawInput` | 2.4% | 4.0% | 子聚合更新 | -| `__memset_avx512` / `get_page_from_freelist` / `_int_malloc` | 合计 ~6% | 合计 ~5% | 新 group 内存分配 | -| `RowContainer::initializeRow` / `HashStringAllocator::clear` | ~3% | ~3% | 新 group 初始化 | -| `RowContainer::extractColumn` | 1.6% | 1.6% | 结果列抽取 | -| `VectorHasher::makeValueIdsFlatNoNulls` | 1.6% | 1.6% | key 编码 | - -关键观察: - -1. **两条路径的执行线程热点几乎重合**:top 热点都是 `clear_page_erms` + 内存分配 + `arrayGroupProbe` + - 各 accumulator 的 `addRawInput`,extract 相关符号(`extractColumn` / ROW child 写出)self time 都不到 2%。 -2. **extract 已经不是这个 case 的瓶颈**:第 10 章 ROW output raw descriptor 已把 extract helper 削掉, - 火焰图里 extract 已沉到噪声级别。partial avg extract 端到端略慢,**不是 extract kernel 导致的**。 -3. **真正的成本是 high-cardinality 的新 group 物化**:该 case groups=batches×batch_size(每行一个新组), - 每个新 group 都要 `clear_page` + `malloc` + `RowContainer::initializeRow`,这部分是 JIT/no-JIT 共有的固定成本, - 且占执行线程相当大比例。JIT 在这部分没有任何优化空间。 -4. 
**JIT 反而在 `arrayGroupProbe` 上采样更高(4.8% vs 1.6%)**:在「每行新组」的极端高基数下,JIT chunk 的 - group probe 调用方式相对 no-JIT 没有优势,叠加 add kernel 节省有限,导致端到端被新组物化稀释后呈现约 0.9x。 - -### 11.4 P1 结论 - -- partial_avg_extract 的 ROW output 瓶颈(第 9/10 章定位的 helper / dynamic_cast)已被 raw descriptor 解决, - 火焰图确认 extract self time 已 <2%。 -- 该 case 当前端到端约 0.9x 的剩余差距**不在 JIT 可优化范围内**:主导成本是 high-cardinality「每行新 group」 - 带来的 `clear_page` / 内存分配 / `RowContainer::initializeRow` / `arrayGroupProbe`,JIT 与 no-JIT 共享这部分开销, - JIT 能优化的 add/extract kernel 占比已被压得很低。 -- 因此 P1 的处理结论是:**partial_avg_extract 不再作为独立优化项继续深挖**。它代表的是「聚合计算占比极低、 - 新组物化占比极高」的负向场景,应通过**白名单/启发式**避免对这类 high-cardinality partial 聚合启用 JIT, - 而不是继续优化 extract 本身。 -- 真正还能换来 add kernel 收益的是 P2(flat/no-null add_dense 快路径)和 P3(下沉 per-row accumulator null clear), - 它们作用于计算占比高的 case,优先级高于继续打磨 partial_avg_extract。 - -附上此次优化后的benchmark report - -``` -$ ./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark -============================================================================ -[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s -============================================================================ -width4_sum_nojit 2.58ms 387.72 -width4_sum_jit 2.36ms 422.95 ----------------------------------------------------------------------------- -width4_avg_nojit 3.15ms 317.76 -width4_avg_jit 2.69ms 372.41 ----------------------------------------------------------------------------- -width4_min_nojit 2.46ms 407.31 -width4_min_jit 2.47ms 405.55 ----------------------------------------------------------------------------- -width4_count_nojit 2.40ms 416.38 -width4_count_jit 1.91ms 524.51 ----------------------------------------------------------------------------- -width4_merge_sum_nojit 3.70ms 270.06 -width4_merge_sum_jit 3.31ms 302.23 ----------------------------------------------------------------------------- -width4_merge_avg_nojit 4.53ms 220.94 -width4_merge_avg_jit 3.67ms 272.67 
----------------------------------------------------------------------------- -width4_merge_min_nojit 3.54ms 282.70 -width4_merge_min_jit 3.37ms 296.70 ----------------------------------------------------------------------------- -width4_merge_count_nojit 3.43ms 291.51 -width4_merge_count_jit 2.87ms 348.60 ----------------------------------------------------------------------------- -width8_sum_nojit 4.60ms 217.19 -width8_sum_jit 3.41ms 293.12 ----------------------------------------------------------------------------- -width8_avg_nojit 5.27ms 189.92 -width8_avg_jit 4.10ms 244.09 ----------------------------------------------------------------------------- -width8_min_nojit 3.76ms 266.31 -width8_min_jit 3.53ms 283.64 ----------------------------------------------------------------------------- -width8_count_nojit 4.31ms 231.77 -width8_count_jit 2.37ms 422.36 ----------------------------------------------------------------------------- -width8_merge_sum_nojit 6.17ms 162.07 -width8_merge_sum_jit 4.78ms 209.35 ----------------------------------------------------------------------------- -width8_merge_avg_nojit 7.26ms 137.75 -width8_merge_avg_jit 5.67ms 176.25 ----------------------------------------------------------------------------- -width8_merge_min_nojit 5.26ms 189.99 -width8_merge_min_jit 4.82ms 207.44 ----------------------------------------------------------------------------- -width8_merge_count_nojit 5.73ms 174.57 -width8_merge_count_jit 3.55ms 281.72 ----------------------------------------------------------------------------- -width16_sum_nojit 8.90ms 112.38 -width16_sum_jit 5.97ms 167.52 ----------------------------------------------------------------------------- -width16_avg_nojit 10.59ms 94.47 -width16_avg_jit 7.29ms 137.08 ----------------------------------------------------------------------------- -width16_min_nojit 7.66ms 130.62 -width16_min_jit 6.34ms 157.71 ----------------------------------------------------------------------------- 
-width16_count_nojit 7.70ms 129.92 -width16_count_jit 3.51ms 284.55 ----------------------------------------------------------------------------- -width16_merge_sum_nojit 11.17ms 89.55 -width16_merge_sum_jit 7.29ms 137.12 ----------------------------------------------------------------------------- -width16_merge_avg_nojit 14.74ms 67.86 -width16_merge_avg_jit 10.45ms 95.71 ----------------------------------------------------------------------------- -width16_merge_min_nojit 10.53ms 94.92 -width16_merge_min_jit 8.39ms 119.26 ----------------------------------------------------------------------------- -width16_merge_count_nojit 10.02ms 99.78 -width16_merge_count_jit 5.30ms 188.84 ----------------------------------------------------------------------------- -width32_sum_nojit 17.49ms 57.17 -width32_sum_jit 12.44ms 80.38 ----------------------------------------------------------------------------- -width32_avg_nojit 20.01ms 49.97 -width32_avg_jit 15.68ms 63.77 ----------------------------------------------------------------------------- -width32_min_nojit 15.48ms 64.62 -width32_min_jit 12.96ms 77.16 ----------------------------------------------------------------------------- -width32_count_nojit 16.00ms 62.52 -width32_count_jit 7.39ms 135.35 ----------------------------------------------------------------------------- -width32_merge_sum_nojit 22.50ms 44.45 -width32_merge_sum_jit 17.12ms 58.41 ----------------------------------------------------------------------------- -width32_merge_avg_nojit 27.85ms 35.90 -width32_merge_avg_jit 21.32ms 46.90 ----------------------------------------------------------------------------- -width32_merge_min_nojit 20.59ms 48.56 -width32_merge_min_jit 17.22ms 58.09 ----------------------------------------------------------------------------- -width32_merge_count_nojit 19.84ms 50.39 -width32_merge_count_jit 11.76ms 85.03 ----------------------------------------------------------------------------- -width8_decimal_sum_nojit 11.94ms 83.76 
-width8_decimal_sum_jit 9.86ms 101.41 ----------------------------------------------------------------------------- -width8_decimal_avg_nojit 16.31ms 61.30 -width8_decimal_avg_jit 14.75ms 67.81 ----------------------------------------------------------------------------- -width8_double_min_nojit 4.97ms 201.05 -width8_double_min_jit 4.17ms 239.76 ----------------------------------------------------------------------------- -width8_double_max_nojit 4.29ms 233.10 -width8_double_max_jit 3.95ms 253.45 ----------------------------------------------------------------------------- -width8_high_card_partial_avg_extract_nojit 62.12ms 16.10 -width8_high_card_partial_avg_extract_jit 68.54ms 14.59 ----------------------------------------------------------------------------- -width8_high_card_partial_sum_extract_nojit 25.27ms 39.58 -width8_high_card_partial_sum_extract_jit 23.10ms 43.29 ----------------------------------------------------------------------------- -``` - -## 12. P2:flat/identity add_dense 快路径验证(已验证无收益,回退) - -### 12.1 优化假设 - -第 9.6.2 / 第 11 章曾把 add_dense 主路径的 `indices[row]` 间接寻址列为 P2 优化点: -flat(identity mapping)输入时 `indices[row] == row`,`loadDecodedValue` -(`bolt/jit/aggregation/HashAggrJit.cpp`)每行先 `index = indices[row]` 再 -`values[index]`,这一级 load 在 flat 下被认为是可省的冗余,预期省掉后能让取值-累加 -循环更利于向量化。 - -### 12.2 实现与验证 - -按方案 A 实现: - -1. `HashAggrJitDecodedInput` 增 `identityMapping` 标记字段; -2. `GroupingSet` 在准备 descriptor 时用 `DecodedVector::isIdentityMapping()` 填标记; -3. 
`loadDecodedValue` 在 IR 里据标记选择直接用 `row` 还是 `indices[row]`。 - -验证: - -- **功能**:编译通过;dump add_dense IR 确认 identity 分支正确生成、descriptor - trailing bool 的 offset 读取正确(`align 1`)。spark aggregate JIT 单测 - (`bolt_functions_spark_aggregates_test`,`--gtest_filter='*hashAggrJit*'`) - 3 passed / 2 failed,其中 2 个失败(`hashAggrJitMergeAndExtract`、 - `hashAggrJitAllNullGroup`)经 `git stash` 对比确认是**基线既有 bug,与 P2 无关**, - P2 未引入新回归。 -- **性能**:分别构建 baseline / P2 两个 benchmark binary,交替多轮对比 - `width8/16/32` 的 sum/avg/min jit 耗时。 - -### 12.3 实测结果 - -| 实现方式 | 相对基线 | 说明 | -|----------|----------|------| -| select 版(IR 内 `select` 选 index) | **慢约 3–6%** | `select` 仍无条件 load `indices[row]`,额外多算 flag load + select,净增指令 | -| branch 版(控制流跳过 `indices[row]` load) | **基本持平** | 多轮差异均在 ±1–2% 噪声内,无可测收益,且增加 IR 复杂度 | - -以 `width16_sum_jit` 三轮交替为例(branch 版):base 6.48 / 6.24 / 6.48ms, -P2 6.20 / 6.31 / 6.25ms——互有高低,落在噪声范围内。 - -### 12.4 结论 - -- **P2 在当前硬件 / 工作负载上没有可测收益,改动已全部回退到基线。** -- 根因:第 11 章把 `indices[row]` 当瓶颈的假设在实测中不成立。flat 输入下 - `indices` 是连续数组的顺序读,**硬件预取使其几乎零成本**,省掉它换不来收益; - select 版反而因多余指令小幅变慢。 -- 按「只做直接必要、不过度工程」的原则,无收益且增加复杂度的改动不保留。 -- 后续若再优化 add_dense 主路径,方向应转向真正的访存瓶颈(如 accumulator 在 - RowContainer 中的非连续布局),而非已被预取覆盖的 `indices[row]` 间接寻址。 -- P3(下沉 per-row accumulator null clear)的待确认正确性约束(新组创建与首次更新 - 是否同 batch)经评估不成立、争议较大,暂缓,不在本轮实施。 - ---- - -## 13. 
Decimal sum/avg add/merge 纯 IR 化 - -### 13.1 背景 - -decimal sum/avg 的 add/merge 主路径此前不是真正的 inline IR:每行通过 -`CreateCall(jit_HashAggrUpdate/MergeDecimal*)` 把 i128 加法 + 溢出检测转交 C++ -runtime helper(`jitHashAggrAddWithOverflow`)。即 IR 只完成 decode + 路由,真正 -算子在跨函数调用里执行——付出了 LLVM 的代价却没拿到 inline 红利。 - -### 13.2 改动 - -- 新增 `HashAggrJitCodegen::emitDecimalAddWithOverflow`:纯 IR 实现 i128 - `CreateAdd` + 溢出检测(`(a>0&&b>0&&r<0)||(a<0&&b<0&&r>=0)`,≤8 条 IR), - 溢出计数用 `posOverflow - negOverflow` 累加。 -- `DecimalSumOps` / `DecimalAvgOps` 的 init/add-raw/add-merge 全部改为纯 IR: - - init:直接 store sum/overflow/(count|isEmpty),替代 `jit_HashAggrInitDecimal*`。 - - add/merge:`emitDecimalAddWithOverflow` + IR 内 `++count` / `isEmpty &&=`。 - - state 字段访问用 `offsetof(JitDecimal*State, field)` 派生 offset,避免硬编码。 -- 删除不再被调用的 `jit_HashAggrInit/Update/MergeDecimal*` runtime helper 及其 - builtin 声明、`jitHashAggrAddWithOverflow`。 -- per-row 的跨函数调用从 N 次降为 0(add 主路径全部内联到循环体)。 - -### 13.3 性能(width8,bm_min_iters=50) - -| case | 改前 jit | 改后 jit | nojit(参考) | 改善 | -|------|----------|----------|----------------|------| -| width8_decimal_sum | 9.86ms | **9.01ms** | 11.79ms | ~9% | -| width8_decimal_avg | 14.75ms | **13.88ms** | 16.72ms | ~6% | - -- nojit 基线基本不变,说明提升来自 JIT 侧 add/merge 内联,而非环境波动。 -- 多轮测量 decimal_sum_jit 稳定在 8.1–9.0ms 区间(取决于机器负载),均优于改前。 -- 收益幅度小于「翻倍」的乐观预期:i128 算术本身有成本,且热循环还有 group - 寻址 / null 处理开销,per-row call 的消除只压缩了其中一部分。 - -### 13.4 正确性 - -- decimal 专项单测全部通过:`decimalSum` / `decimalGlobalSumOverflow` / - `decimalGroupBySumOverflow` / `decimalLargeCountRowsOverflow` / - `decimalSomeGroupsAllnullValues`(覆盖溢出、全 null 组等关键路径)。 -- extract 的 decimal 计算(依赖 `DecimalUtil` 精度判定、每组一次、非热路径) - 保留 runtime helper,不在本次范围。 - ---- - -## 14. partial avg extract 去掉运行时 fast/helper 分支 - -### 14.1 背景 - -`emitPartialAvgResult` 此前在 IR 里有 `hasRawRowOutput ? 
fast : helper` 的运行时 -分支(3 个 BasicBlock + 1 条件跳转):当 partial avg 输出 ROW 的 sum/count 子字段 -为 FLAT 时走直写 fast 路径,否则回退 `jit_HashAggrSetPartialAvgDouble` helper。 -但该分支判定的是**循环不变量**(`rowField0Values` 在整个 extract 调用内不变)。 - -### 14.2 改动 - -- 把 fast/helper 的选择从「运行时」前移到「extract 准入」: - `fillHashAggrJitPartialAvgOutput` 改为返回 bool,当 ROW 子字段非 FLAT - (dictionary/constant 包装)时返回 false;`runHashAggrJitExtractChunks` 据此 - 令 `canRunChunk=false`、回退非 JIT 并打 VLOG(`skipReason="partial avg row - fields are not flat"`)。 -- 这样保证进入 JIT 的 chunk 其 rowField0/1 必被填充,IR 里直接走纯 fast 路径。 -- `emitPartialAvgResult` 删除运行时分支与 3 个 BasicBlock;删除不再被调用的 - `jit_HashAggrSetPartialAvgDouble` runtime helper、builtin 声明及其 - `ComplexVector.h` include。 - -### 14.3 性能(bm_min_iters=50,基线=分支版,优化=纯 fast) - -| case | 基线 | 优化后(2 轮) | 变化 | -|------|------|--------------|------| -| width8_avg_jit | 4.26ms | 4.15 / 4.21ms | ~持平–3% | -| width16_avg_jit | 8.75ms | 7.55 / 7.58ms | ~14% | -| width8_merge_avg_jit | 6.22ms | 5.75 / 6.17ms | 波动,约 0–8% | -| width16_merge_avg_jit | 11.44ms | 10.85 / 10.70ms | ~5–6% | -| width8_high_card_partial_avg_extract_jit | 74.13ms | 70.00 / 68.84ms | ~6–7% | - -- 整体小幅改善或持平,无回归。改善幅度有限且部分用例有运行间波动——符合预期: - 被删的分支是循环不变量,LLVM LICM + 分支预测本就覆盖了大部分开销,去掉它主要 - 减少了 codegen 出的 BasicBlock 数与少量恒命中的比较/跳转。 -- 价值更多在**正确性与可维护性**:把「子字段非 FLAT」从 IR 兜底分支收敛为 plan - 阶段的显式准入回退,IR 不再生成永远走同一侧的运行时分叉。 - - -### 14.4 正确性 - -- partial avg / average 相关单测通过:`hashAggrJitPartialAvgExtractAccumulators` - (直接覆盖本次 fast 路径)、`avgDecimal` / `avgAllNulls` / - `rowBasedSpillDecimalAvg` / `hashAggrJitDecimalSumAndFloatingMinMax` / - `hashAggrJitSplitsContiguousSegments`。 -- 无新增回归(`hashAggrJitMergeAndExtract` / `hashAggrJitAllNullGroup` 仍 FAIL, - 系既有 P0 bug,见 todolist,与本次无关)。 - -### 14.5 当前性能 - - -``` -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_regex="(width8|width16)" -============================================================================ -[...]c/benchmarks/HashAggrJitBenchmark.cpp relative time/iter iters/s 
-============================================================================ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_merge_sum_nojit 6.36ms 157.19 -width8_merge_sum_jit 4.62ms 216.34 ----------------------------------------------------------------------------- -width8_merge_avg_nojit 7.57ms 132.09 -width8_merge_avg_jit 5.97ms 167.52 ----------------------------------------------------------------------------- -width8_merge_min_nojit 5.49ms 182.15 -width8_merge_min_jit 4.70ms 212.96 ----------------------------------------------------------------------------- -width8_merge_max_nojit 5.59ms 178.76 -width8_merge_max_jit 4.78ms 209.31 ----------------------------------------------------------------------------- -width8_merge_count_nojit 5.98ms 167.10 -width8_merge_count_jit 3.62ms 276.52 ----------------------------------------------------------------------------- -width16_merge_sum_nojit 11.78ms 84.87 -width16_merge_sum_jit 7.96ms 125.58 ----------------------------------------------------------------------------- -width16_merge_avg_nojit 15.00ms 66.69 -width16_merge_avg_jit 10.79ms 92.66 ----------------------------------------------------------------------------- -width16_merge_min_nojit 10.42ms 95.97 -width16_merge_min_jit 7.78ms 128.48 ----------------------------------------------------------------------------- -width16_merge_max_nojit 10.83ms 92.30 -width16_merge_max_jit 8.04ms 124.41 ----------------------------------------------------------------------------- -width16_merge_count_nojit 9.77ms 102.32 -width16_merge_count_jit 5.12ms 195.16 
----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_high_card_merge_sum_nojit 63.70ms 15.70 -width8_high_card_merge_sum_jit 52.54ms 19.03 ----------------------------------------------------------------------------- -width8_high_card_merge_avg_nojit 88.21ms 11.34 -width8_high_card_merge_avg_jit 83.18ms 12.02 ----------------------------------------------------------------------------- -width8_high_card_merge_min_nojit 61.33ms 16.30 -width8_high_card_merge_min_jit 53.34ms 18.75 ----------------------------------------------------------------------------- -width8_high_card_merge_max_nojit 63.09ms 15.85 -width8_high_card_merge_max_jit 52.99ms 18.87 ----------------------------------------------------------------------------- -width8_high_card_merge_count_nojit 60.97ms 16.40 -width8_high_card_merge_count_jit 56.22ms 17.79 ----------------------------------------------------------------------------- -width16_high_card_merge_sum_nojit 113.48ms 8.81 -width16_high_card_merge_sum_jit 90.50ms 11.05 ----------------------------------------------------------------------------- -width16_high_card_merge_avg_nojit 160.62ms 6.23 -width16_high_card_merge_avg_jit 146.38ms 6.83 
----------------------------------------------------------------------------- -width16_high_card_merge_min_nojit 116.27ms 8.60 -width16_high_card_merge_min_jit 92.11ms 10.86 ----------------------------------------------------------------------------- -width16_high_card_merge_max_nojit 113.06ms 8.84 -width16_high_card_merge_max_jit 91.86ms 10.89 ----------------------------------------------------------------------------- -width16_high_card_merge_count_nojit 100.63ms 9.94 -width16_high_card_merge_count_jit 89.21ms 11.21 ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_merge_decimal_sum_nojit 19.57ms 51.09 -width8_merge_decimal_sum_jit 21.68ms 46.12 ----------------------------------------------------------------------------- -width8_merge_decimal_avg_nojit 19.50ms 51.29 -width8_merge_decimal_avg_jit 13.67ms 73.15 ----------------------------------------------------------------------------- -width16_merge_decimal_sum_nojit 39.46ms 25.34 -width16_merge_decimal_sum_jit 42.44ms 23.56 ----------------------------------------------------------------------------- -width16_merge_decimal_avg_nojit 40.01ms 24.99 -width16_merge_decimal_avg_jit 26.90ms 37.17 ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- 
----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -width8_merge_double_min_nojit 6.82ms 146.54 -width8_merge_double_min_jit 5.58ms 179.12 ----------------------------------------------------------------------------- -width8_merge_double_max_nojit 5.89ms 169.65 -width8_merge_double_max_jit 5.25ms 190.33 ----------------------------------------------------------------------------- -width16_merge_double_min_nojit 12.33ms 81.08 -width16_merge_double_min_jit 9.74ms 102.70 ----------------------------------------------------------------------------- -width16_merge_double_max_nojit 10.85ms 92.15 -width16_merge_double_max_jit 8.67ms 115.36 ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- -``` - -## 15. decimal_avg final extract JIT 补齐说明 - -### 15.1 背景 - -在本轮修复前,decimal avg 只有 partial extract 走 JIT helper,final extract 仍留在 non-JIT 路径: - -- planner/codegen 侧通过 `canCompileDecimalAvgExtract(..., partialOutput)` 仅允许 partial path; -- runtime 侧 `jit_HashAggrExtractFinalDecimalAvg` 是空 stub,仅用于 link 成功; -- 因此历史上部分 `decimal_avg` benchmark 结果,实际测到的是“JIT add/merge + non-JIT final extract”的混合路径。 - -这也是第 4 章里“decimal_avg 曾优于 decimal_sum”的一个背景因素:当时 decimal avg 没有承担 final decimal -rescale/divide 的 JIT extract 成本。 - -### 15.2 本轮实现 - -本轮已补齐 final decimal avg extract 的 JIT 支持,策略是**继续保持 helper 模式**,不把 Spark decimal avg 的 -divide / overflow / precision-rescale 逻辑直接展开成 LLVM IR。 - -具体改动: - -1. **放开 codegen**:`decimal avg` 的 extract 现在 partial / final 都允许编译; -2. **扩展 helper ABI**:avg extract helper 额外接收最终结果 decimal 的 `resultPrecision/resultScale`; -3. 
**实现 final runtime helper**:在 runtime 中镜像 non-JIT `computeAvg` 语义: - - `adjustSumForOverflow` - - `divideWithRoundUp` - - `rescaleWithRoundUp` - - short / long decimal 分类型写回 `FlatVector` -4. **benchmark 口径更新**:`HashAggrJitBenchmark` 中 `decimal_sum/decimal_avg` 统一按 `PartialFinal` 路径评测, - 避免继续把 decimal avg 记成“只测 partial + 非 JIT final extract”的旧口径。 - -### 15.3 功能验证 - -本轮未新增一组完整 benchmark 数据表,但已完成功能与构建验证: - -- 构建通过:`bolt_thrustjit`、`bolt_exec`、`bolt_functions_spark_aggregates_test` -- Average 相关测试通过: - - `AverageAggregationTest.avgAllNulls` - - `AverageAggregationTest.avgDecimal` - - `AverageAggregationTest.avgDecimalWithMultipleRowVectors` - - `AverageAggregationTest.rowBasedSpillDecimalAvg` - -说明 final decimal avg extract JIT 至少已经满足当前 Spark avg 语义下的基础正确性要求: - -- `count == 0` 输出 null; -- sum overflow 无法修正时输出 null; -- divide / rescale overflow 时输出 null; -- short decimal 与 long decimal 结果类型都可写回。 - -### 15.4 对阅读本报告的影响 - -1. **第 2/3/4 章中的早期 decimal_avg 结论需要加注理解**:这些历史结论产生时,final decimal avg extract 还未走 JIT。 -2. **后续若继续比较 decimal_avg 的 JIT/no-JIT 收益,应以当前 `PartialFinal` benchmark 口径为准**。 -3. **当前 decimal_avg benchmark 的收益解释更完整**:它现在同时覆盖 JIT add、JIT merge 和 JIT final extract, - 比之前更接近真实生产路径。 - -换句话说:从这一节之后,文档里关于 decimal avg 的性能讨论应默认理解为“**final extract JIT 已补齐**”的版本; -如果引用更早的数据,需要显式说明那是旧口径历史快照。 - -## 16. 
decimal sum merge 输入 row-field 快路径优化 - -### 16.1 背景 - -补齐 final extract 后,新增了 `*_merge_decimal_sum` / `*_merge_decimal_avg`(`PartialFinal` 口径)benchmark。 -其中 **`width16_merge_decimal_sum_jit` 反而比 non-JIT 慢**(约 42ms vs 39ms),而同口径的 decimal avg 却是 JIT 更快。 - -定位结论:瓶颈不在算术本身,而在 **final aggregation 的 merge 输入读取**。 - -- decimal sum merge 的中间结果是 `ROW(sum:decimal, isEmpty:bool)`; -- merge 热路径每行都要读 field0(sum) 与 field1(isEmpty),并判 field0 null(见 `DecimalSumOps.cpp`); -- 但 `fillHashAggrJitRowFieldInputs()` 当时**只为 avg 预填 row-field raw 指针**,decimal sum 没填; -- 于是 JIT 的 `loadDecodedRowField` / `isDecodedRowFieldNull` 全部掉到 helper slow path - (`jit_GetDecodedRowFieldI128 / I8 / IsNull`),而这些 helper **每次调用都重建一个 field 级 `DecodedVector`**; -- width16 下 slot 数翻倍,这个每行固定开销被线性放大,最终把 JIT 收益吃光。 - -decimal avg 之所以更快,正是因为它的 merge 输入早已走了 raw row-field 快路径(avg 的 `ROW(sum, count)`)。 - -### 16.2 改动 - -分两步把 decimal sum 的 merge 输入快路径补齐。 - -**第一步:field0(sum) raw fast path** - -扩展 `fillHashAggrJitRowFieldInputs()`(`bolt/exec/GroupingSet.cpp`),从“仅 avg”扩展到“avg + decimal sum”: - -- decimal sum 的 field0(sum, int128) 填充 `rowField0Values` / `rowField0Nulls`; -- 这样 JIT 读取 sum 与判 sum null 直接命中 `loadDecodedRowField` 的 raw fast path,不再每行重建 `DecodedVector`。 - -**第二步:field1(isEmpty) bit-packed bool fast path** - -field1 的 `isEmpty` 是 **bit-packed bool**,没有按字节排布的 scalar 指针,不能复用普通 fast path,因此单独新增一条按位读取的快路径: - -- 新增 `HashAggrJitCodegen::loadDecodedRowFieldBool()`(`bolt/jit/aggregation/HashAggrJit.cpp`): - 把 `rowField1Values` 视为 i64 word 数组,`word = index>>6`、`bit = index&63`,直接 `(words[word] >> bit) & 1`; - raw 指针为空时回退 `jit_GetDecodedRowFieldI8` helper; -- `DecimalSumOps.cpp` 的 merge 改用 `loadDecodedRowFieldBool` 读 `isEmpty`; -- `fillHashAggrJitRowFieldInputs()` 为 decimal sum 填 field1 的 bit 缓冲区指针, - 通过 `valuesAsVoid()` 取(注意:`FlatVector::rawValues()` 对 bool 会抛 `UNSUPPORTED`,必须用 `valuesAsVoid()`)。 - -涉及文件: - -- `bolt/exec/GroupingSet.cpp`:`fillHashAggrJitRowFieldInputs` 扩展支持 decimal sum,并填 field0/field1 指针; -- 
`bolt/jit/aggregation/HashAggrJit.{h,cpp}`:新增 `loadDecodedRowFieldBool`; -- `bolt/jit/aggregation/ops/DecimalSumOps.cpp`:merge 改用 bool 快路径读 `isEmpty`。 - -### 16.3 性能(PartialFinal 口径,bm_min_iters 默认) - -| case | nojit | jit(初始) | jit(+field0) | jit(+bool) | -|---|---|---|---|---| -| width8_merge_decimal_sum | 19.55ms | 慢于 nojit | 17.24ms | **16.19ms** | -| width16_merge_decimal_sum | 39.44ms | ~42ms(慢于 nojit) | 33.27ms | **31.30ms** | -| width8_merge_decimal_avg | 20.68ms | —(本就更快) | — | **14.55ms** | -| width16_merge_decimal_avg | 40.09ms | —(本就更快) | — | **28.12ms** | - -要点: - -- `width16_merge_decimal_sum_jit`:42ms(慢于 nojit)→ 33.3ms(field0 快路径)→ **31.3ms**(再加 bool 快路径),已稳定反超 nojit 的 39ms; -- decimal avg 不受负面影响,仍保持 JIT 更快。 - -### 16.4 正确性 - -- 两步快路径读到的都是与原 helper 完全相同的底层数据,仅省掉了每行的 `DecodedVector` 重建,无语义变化; -- 实现期间踩到一个坑:最初用 `FlatVector::rawValues()` 取 bit 缓冲区,运行时抛 - `BoltUserError: rawValues() for bool is not supported`,改用 `valuesAsVoid()` 后正常; -- benchmark 真实 query 在 JIT / non-JIT 双路径下均正常执行、无崩溃、无异常。 diff --git a/doc/hashaggr-jit-code-review.md b/doc/hashaggr-jit-code-review.md deleted file mode 100644 index 8688e6fdf..000000000 --- a/doc/hashaggr-jit-code-review.md +++ /dev/null @@ -1,133 +0,0 @@ -# Hash Aggregation JIT 代码 Review 清单 - -> 对比范围:`d4e69030bfbe1d27eb31e6ad49027833bfce2c8e..HEAD`(hash aggr jit 支持代码,~7956 行 / 40 文件) -> 用途:逐条优化的工作清单。**本文档仅记录问题与建议,不含代码修改。** -> 关注维度:① 代码坏味道 ② JIT 与非 JIT 关键数据结构(raw input / intermediate / group / result)一致性 ③ 框架与聚合函数耦合残留 ④ 数据结构冗余 - ---- - -## 0. 总体结论 - -- **架构方向正确**:`HashAggrJitOps` 回调结构体已把各聚合语义下沉到 `ops/`,框架三大骨架 `genInitIR/genAddDenseIR/genExtractIR` 没有 `switch(kind)` 大分支。 -- **数据布局基准一致**:accumulator 起始 offset、null byte/mask 全部来自框架 `Aggregate::createHashAggrJitSlot`(`Aggregate.cpp:335`),JIT 不硬编码;内部字段用 `offsetof + static_assert(is_standard_layout)` 锁定。6 对聚合实现的字段顺序、ROW 子列顺序、null 标记、溢出语义**当前全部一致**。 -- **主要待优化点**: - 1. decimal 的 IR 生成(add-with-overflow / extract)泄漏在框架 `HashAggrJitCodegen` 中(耦合)。 - 2. 
三大 gen 函数与 Sum/MinMax ops 存在大量样板重复(坏味道)。 - 3. JIT decimal state 与非 JIT accumulator 是「镜像复制」而非源码共享,缺跨编译单元交叉断言(一致性风险 + 冗余)。 - 4. descriptor 的 decimal 专属字段被所有 slot 冗余携带(数据冗余)。 - -优先级建议:先做 **C1(解耦 decimal)+ S1/S2(消骨架与 Sum/MinMax 重复)+ D1/D2(消 decimal 双定义与死字段)**,收益最大。 - ---- - -## 1. 框架与聚合函数耦合残留(关注点③) - -| # | 位置 | 问题 | 建议 | 严重度 | -|---|---|---|---|---| -| **C1** | `HashAggrJit.cpp:797-841` `emitDecimalSumExtract`/`emitDecimalAvgExtract`;`843-874` `emitDecimalAddWithOverflow`;头文件 `HashAggrJit.h:266-291` | decimal 专属 IR 生成(i128 累加+overflow 进位、调 decimal runtime、读 precision/scale/auxPrecision/auxScale、long/short decimal 判断)是**框架类 `HashAggrJitCodegen` 的成员**;ops 只是转手调用。decimal 知识泄漏进框架。 | 下沉到 `DecimalSumOps.cpp`/`DecimalAvgOps.cpp` 内 `static` 辅助函数,只依赖通用原语(loadValue/storeValue/builder/module)。框架类不应有任何带 "Decimal" 的方法。 | 高 ✅ **已完成** | -| **C2** | `HashAggrJit.cpp:118-148` `ensureBuiltinDeclarations` | 框架构造函数无条件声明 4 个 decimal extract runtime 签名,即使 chunk 内无 decimal。 | 给 `HashAggrJitOps` 增加可选 `declareRuntime(llvm::Module&)` 回调,由各 decimal ops 自行声明;框架只声明通用 `jit_HashAggrResizeVector`。 | 中 | -| **C3** | `HashAggrJit.cpp:89-102` `kHashAggrRuntimeLinkAnchors` | 框架 TU 的链接锚点引用了 decimal 专属符号 `jit_HashAggrExtractFinalDecimalSum`,使框架编译期强依赖 decimal runtime。 | 框架锚点改用通用 runtime 符号;decimal runtime 锚点由 decimal ops TU 自持。 | 中 | -| **C4** | `HashAggrJit.cpp:1002`(配套 `GroupingSet.cpp:1177`) | 框架骨架 `if (checkInputNulls && !slot.desc.countStar)` 直接判 count 专属 flag。 | 用通用语义字段(如 `consumesInput`/`hasScalarInput`)替代 `countStar` 在框架层的判断;count 语义判定保留在 `CountOps`。 | 低 | -| **C5** | `HashAggrJit.cpp:1115-1129`(及被注释的 `signature()` `1197-1210`) | chunk 名拼接直接读 `countStar/mergeInput/decimal/kind` 等具体 flag。属较合理的元数据消费,但仍耦合 flag 名。 | 可提供 `ops->signatureSuffix(slot)` 回调由算子补充自身特征。 | 低 | - -> 编排层(`GroupingSet.cpp` / `Aggregate.cpp` 的 slot/descriptor 构建)已通过虚函数 `supportsHashAggrJit`/`createHashAggrJitDescriptor` 干净解耦,无按 kind 写死映射。 - ---- - -## 2. 
JIT 与非 JIT 数据结构一致性(关注点②) - -### 2.1 逐对结论(当前均一致) - -| 聚合 | JIT | 非 JIT | 结论 | -|---|---|---|---| -| SUM | `SumOps.cpp` 单标量写 `slot.offset` | `SumAggregateBase.h` 裸标量 | ✅ 一致 | -| AVG | `AvgOps.cpp:16-23` `{double sum; int64 count}` + offsetof | `SumCount` `AverageAggregateBase.h:81-84` | ✅ 一致(ROW 子列 `{sum,count}`) | -| DECIMAL SUM | `JitDecimalSumState{sum,overflow,isEmpty}` + offsetof | `DecimalSum` `DecimalSumAggregate.h:37-48` | ✅ 一致(ROW `{sum,isEmpty}`,溢出哨兵语义一致) | -| DECIMAL AVG | `JitDecimalAvgState{sum,count,overflow}` + offsetof | `LongDecimalWithOverflowState` `DecimalAggregate.h:45-82` | ✅ 一致(ROW `{sum,count}`) | -| COUNT | `CountOps.cpp` 单 i64,结果永不 null | `CountAggregate.cpp` 裸 int64 | ✅ 一致 | -| MIN/MAX | `MinMaxOps.cpp` 单标量,null 表「空」 | `MinMaxAggregates.cpp` 裸标量 | ✅ 一致(Int128/Bool 回退非 JIT) | - -### 2.2 一致性风险点(靠人工同步维持,需重点盯) - -| # | 位置 | 问题 | 风险 | -|---|---|---|---| -| **R1** ✅ **已完成** | `AvgOps.cpp:16-23` vs `AverageAggregateBase.h:81-84` | ~~`AvgAccumulatorLayout` 与 `SumCount` 是两份独立定义,跨编译单元无法交叉 `static_assert(sizeof/offsetof==)`。改一处忘改另一处会静默写错 count 偏移。~~ 已抽出零依赖头 `SumCount.h` 作为唯一权威定义;JIT 端 `using AvgAccumulatorLayout = functions::aggregate::SumCount`,offset 由权威结构 `offsetof` 派生,自动同步,镜像漂移消除。 | 中 | -| **R2** ✅ **已完成** | `HashAggrJitDecimalState.h:16-26` vs `DecimalSumAggregate.h:37-48` / `DecimalAggregate.h:45-82` | ~~同为镜像复制。注意 DecimalSum 是 `{sum,overflow,isEmpty}`、DecimalAvg 是 `{sum,count,overflow}`,**overflow 字段位置不同**;且 `LongDecimalWithOverflowState::serialize()` 顺序(count,overflow,sum)又与内存布局不同,极易混淆。JIT 只读内存不走 serialize,当前正确但无编译期交叉校验。~~ 已抽出零依赖头 `DecimalAccumulatorLayout.h`(`DecimalSumAccumulatorLayout`/`LongDecimalWithOverflowLayout` 两个 POD 布局基类);`DecimalSum`/`LongDecimalWithOverflowState` 继承之(只加方法、不加数据成员,保持 standard-layout),JIT 端 `using JitDecimalSumState/JitDecimalAvgState` 别名同一布局基类。布局自动同步,4 处 `static_assert(is_standard_layout_v)` 兜底防派生类加字段。 | 中 | -| **R3** | `HashAggrDecimalRuntime.cpp:29-110` | 4 个 runtime helper 逐行复制了非 JIT 的 
`computeFinalValue`/`computeAvg`/`adjustSumForOverflow`/`rescaleWithRoundUp` 及常量(`kCountPrecision=20` 等)。属行为复制,非 JIT 改 decimal 语义需同步两处。 | 中 | -| **R4** | `AvgOps.cpp:114-129` | 全 null group partial extract 输出 `(0,0)` 且 top-level 非 null,对齐的是 sparksql 重载版 `AverageAggregate.cpp:112-132`(非 lib 基类版)。需确认 JIT 仅用于 sparksql 路径。 | 低 | -| **R5** | `AvgOps.cpp:42-70`/`132` | avg 的最终 null 实际靠 `count==0` 判定,accumulator null byte 对 avg 不参与结果 null,存在但冗余,易误读。 | 低 | -| **R6** | `MinMaxOps.cpp:74-79`/`SumOps.cpp:51-55`/`AvgOps.cpp:100-103` | `canCompile*Extract` 用 accumulatorKind 白名单回退非 JIT,正确但靠人工维护;新增类型忘更新可能误走 JIT。 | 低 | - ---- - -## 3. 代码坏味道(关注点①) - -### 3.1 高严重度 - -| # | 位置 | 问题 | 建议 | -|---|---|---|---| -| **S1** | `HashAggrJit.cpp:906-925`/`953-978`/`1047-1067` | 三个 gen 函数的 LLVM 函数原型构造、entry/loop/end BB、`numRows<=0` guard、行循环 PHI、groupAddr/group 三段骨架几乎逐字重复 3 份。 | 抽取 `beginGroupLoop()/endGroupLoop()` 公共辅助返回 `{Function*, loop/end BB, PHI* row, group}`;`i8PtrTy/i32Ty/voidTy` 收进 `JitTypes` 缓存。 | -| **S2** | `SumOps.cpp:14-27`/`57-69` vs `MinMaxOps.cpp:14-25`/`81-93`(Avg init 前半段同) | Sum 与 MinMax 的 init(setNull+存0)与 extract(load+isNull+write)逐行相同。 | 抽 `compileZeroInitNullableAccumulator()` 与 `compileSimpleNullableExtract()` 复用。 | -| **S3** ✅ **已完成** | `HashAggrJitDecimalState.h:16-26` | ~~JIT decimal state 与非 JIT accumulator 重复定义,靠 `static_assert(standard_layout)` 无法保证与原结构字段顺序/对齐一致。~~ 已抽出零依赖布局基类(`DecimalAccumulatorLayout.h`),JIT 端用 `using` 别名复用,非 JIT 结构继承同一基类,布局单一权威来源。 | `using` 复用原结构,或加 `static_assert(sizeof/offsetof==)` 钉死并注释「布局必须与 X 同步」。 | - -### 3.2 中严重度 - -| # | 位置 | 问题 | 建议 | -|---|---|---|---| -| **S4** | `HashAggrJitTypes.h:141` + `HashAggrJit.cpp:1197-1210` | 被注释掉的死代码 `signature()`;``(`:15`) 仅服务这段死代码;逻辑与 chunk 名拼接重合。 | 删死代码或与 chunk 名拼接合并为真函数,移除多余 include。 | -| **S5** | `HashAggrJit.h:29-35` AddFn 的 `nextBlock` 参数 | 6/8 ops 实现未用(匿名 `BasicBlock*`),仅 decimal 用;框架调用后又无条件 `CreateBr(nextBlock)`(`:1028`),控制流职责模糊。 | 将「分支到 nextBlock」职责收归框架,decimal overflow 分流改用局部 if/PHI,从签名删除 
`nextBlock`。 | -| **S6** | `SumOps.cpp:51`/`AvgOps.cpp:100`/`MinMaxOps.cpp:74`/`CountOps.cpp:69`/`DecimalAvgOps.cpp:123` | canExtract 第二参数有的匿名有的命名 `partialOutput` 却不用,风格不一。 | 统一:不用就一律匿名,typedef 处注释语义。 | -| **S7** | `HashAggrJit.cpp:692-716` `ScalarOutputAdapterCodegen::write` | `kind` 不支持时静默 no-op(既不写也不报错),与 `RowOutputAdapterCodegen::writeField`(`:762` 用 BOLT_CHECK) 不一致。 | 补 `else BOLT_UNSUPPORTED(...)`。 | -| **S8** | `HashAggrJit.cpp:1115-1129` | 超长 ostringstream 拼接函数名,单字符 flag(`s/x`,`g/r`,`d/n`)无注释,可维护性低。 | 抽 `appendSlotSignature(out, slot)` 并加注释。 | -| **S9** | gen 用 `return false`(`:937-939,1019-1026,1092-1094`);适配器用 `BOLT_UNSUPPORTED`(`:583,590,610,723,759`);`writeField` 用 `BOLT_CHECK`(`:767`) | 同模块对「不支持/非法状态」三种处理方式混用。 | 明确契约:可降级→bool,编程错误/不变量破坏→BOLT_CHECK,头注释写清。 | - -### 3.3 低严重度 - -| # | 位置 | 问题 | 建议 | -|---|---|---|---| -| **S10** | `HashAggrJit.cpp:995-1001`/`1085-1091` | 循环内每 slot `make_unique` 适配器(轻量值类型无需堆分配+虚表)。 | 用 `std::variant<...>` 或栈对象+基类引用。 | -| **S11** | `HashAggrJit.cpp:988`/`1077` | `for (auto i=0; i<…` 〔原文在此处因 `<…>` 被当作标签剥离而截断:S11 的其余内容、S12–S15、第 4 节开头(含 D1–D3)以及下一条目(按后续编号推断应为 D4)的编号与「位置」列开头均已丢失〕 `…->id` 字符串 | 算子身份 enum + 字符串双重标识;`kind` 实际仅 MinMax 区分用到(`MinMaxOps.cpp:48,61`)+ chunk 命名,Sum/Avg/Count 的 kind 与 ops 冗余。 | 只留 `ops` 指针给 MinMax 加 `isMin` 标志或拆两个 ops;或去 id 字符串改 kind 派生名,二选一。 | -| **D5** | `HashAggrJitTypes.h:128`(`decimal`)、`:129-130`(`inputShape/outputShape`) | `decimal` 与「ops 是否 Decimal*」一一对应;shape 与适配器选择(`HashAggrJit.cpp:888-894`)一一对应,属派生型冗余。 | ops 表暴露 `isDecimal`/`defaultShape` 后可去字段;否则注释「与 ops 绑定,prepare 填充」。 | -| **D6** | `HashAggrJitTypes.h:56`/`68` union 两变体各存 `void* vector` | scalar 与 row 输出变体都放一个语义相同的顶层 `vector` 字段。 | 把 `vector` 提到 union 外公共头部。 | -| **D7** | `AvgOps.cpp:16-19`(局部)vs `HashAggrJitDecimalState.h`(共享头);Sum/MinMax 裸标量是隐式约定无 struct | accumulator layout 存放位置与表达方式不一致。 | 统一:都用具名 struct+offsetof 或都注释化。 | - -### 4.3 低严重度 - -| # | 位置 | 问题 | 建议 | -|---|---|---|---| -| **D8** | `HashAggrJitTypes.h:100-109`;`HashAggrJit.cpp:152-154`/`176-189` | `Bool` 在 llvmType 等价 Int8,仅少数处特判,枚举语义重叠。 | 评估改 Int8+`isBool` 标志,或注释说明差异点。 | -| 
**D9** | `HashAggrJitTypes.h:144-152` | slot 与 desc **未**重复携带 offset/null(已确认);但 `desc` 按值内嵌使多 slot 拷贝整份 descriptor(与 D1 叠加)。 | 若 descriptor 可共享,slot 改持 `const HashAggrJitDescriptor*` 减少拷贝与死字段复制。 | - ---- - -## 5. 建议的优化顺序 - -1. **C1** ✅ **已完成**:decimal IR 生成已下沉到 ops(`emitDecimalAddWithOverflow`/`emitDecimalSumExtract` 定义于 `DecimalSumOps.cpp`,`emitDecimalAvgExtract` 定义于 `DecimalAvgOps.cpp`,声明在新增的 `ops/DecimalOps.h`;框架类 `HashAggrJitCodegen` 不再持有任何 "Decimal" 方法)。 -2. **S1 + S2**:消除三大 gen 骨架重复、合并 Sum/MinMax init/extract。 -3. **S3/D2 + R2** ✅ **已完成**:decimal state 双定义已改为继承共享 POD 布局基类(`DecimalAccumulatorLayout.h`)+ JIT `using` 别名复用,布局单一权威来源(同时降一致性风险与冗余)。 -4. **D1/D9**:descriptor decimal 死字段拆出 + slot 改持指针。 -5. **C2/C3**:decimal runtime 声明与链接锚点下沉。 -6. **R1** ✅ **已完成** / **R3**:R1(Avg layout)已通过抽出 `SumCount.h` + JIT `using` 复用消除镜像;R3(decimal runtime 逻辑复制)待加交叉校验/同步注释。 -7. 其余坏味道(S4–S15)、冗余(D3–D8)按批次清理。 diff --git a/hashaggr_jit_refactor_plan.md b/hashaggr_jit_refactor_plan.md deleted file mode 100644 index 63aab908b..000000000 --- a/hashaggr_jit_refactor_plan.md +++ /dev/null @@ -1,1571 +0,0 @@ -# Bolt Hash Aggregation JIT 框架重构落地方案 - -> 目标读者:AI/工程师,按本文档执行即可完成 `hash_aggr_jit` 分支当前框架的重构落地。 -> 适用版本:`dp/bolt @ hash_aggr_jit` 分支(基于 commit `9a65fd2` 之后)。 -> 本方案只描述 **JIT 框架层**重构,不涉及非 JIT codepath。 - ---- - -## 0. TL;DR - -把当前 `HashAggrJitDecodedInput / HashAggrJitOutput / per-aggregate codegen` 这一套耦合实现,重构为 **三层正交架构**: - -``` -┌──────────────────┐ IRRow ┌─────────────┐ IRRow ┌────────────────┐ -│ InputAdapter │ ─────────▶ │ GroupOps │ ─────────▶ │ OutputAdapter │ -│ (Vector → IR) │ │ (IR ↔ Group)│ │ (IR → Vector) │ -└──────────────────┘ └─────────────┘ └────────────────┘ -``` - -三层之间的唯一传输格式是 **LLVM First-Class Aggregate** 类型: - -``` -IRRow_t = llvm::StructType::get(value_type, i1_ty) - = { T, i1 } // T 由 aggregate 自己决定,可以是复合类型 -``` - -`is_null` 永远在第二个字段,框架统一处理;`value_type` 内部结构对框架透明。 - -## 1. 
当前问题(背景) - -落地前必须理解这些已存在的痛点,重构必须**逐项消除**。 - -### 1.1 数据结构无通用性 - -```cpp -// HashAggrJit.h —— 反例 -struct HashAggrJitDecodedInput { - const void* data; - const uint64_t* nulls; - // ... 写死了若干字段,新增 aggregate 类型就要扩字段 -}; -struct HashAggrJitOutput { /* 同上 */ }; -``` - -- **病症**:每加一种聚合 / 一种 vector encoding,就要改这两个结构 + 改 IR 的 hardcoded byte offset。 -- **影响**:ABI 双向耦合(C++ struct ↔ IR offset),任何字段重排都是坑。 - -### 1.2 Vector ↔ IR 与 IR ↔ Group 两段逻辑混在一起 - -每个 `XxxAggregate::codegenAddDense / codegenExtract` 同时做: -1. 从输入 vector decode 出值 -2. 在 IR 里做累加 / 比较 -3. 把结果按 group 内 memory layout 写回 - -→ 三件事完全不正交,维护成本爆炸;且每个 aggregate 都要重新写 vector decoding 逻辑。 - -### 1.3 复合 value 类型(avg)特殊化 - -avg intermediate 当前在多处直接写成三元组 `{f64 sum, i64 count, i1 is_null}`,把 null 处理跟 value 内部结构耦合在一起,框架 helper 无法复用。 - ---- - -## 2. 目标架构 - -### 2.1 核心抽象:`IRRow` - -**契约**: - -```cpp -// 框架级 invariant —— 所有 aggregate 共用 -IRRow_t(value_type) := llvm::StructType::get(value_type, i1Ty) -// ^^^^^^^^^^ ^^^^^ -// field 0 field 1 (is_null) -``` - -**关键决策(已与作者确认)**:当 `value_type` 本身是复合类型(如 avg 的 `{double sum, i64 count}`),**采用嵌套** `{{double, i64}, i1}`,不采用平铺 `{double, i64, i1}`。 - -理由(简版,详细对比见 §6): -- 嵌套保持 `IRRow = {T, i1}` 不变量,框架 helper 完全通用; -- 平铺让 framework 必须知道 T 内部 field 数量,破坏抽象; -- 二者 memory layout 完全相同(24B),lowering 后寄存器分配完全一致,**性能零差异**; -- 未来 stddev / HLL / array_agg 等复合 value 聚合都能复用同一套框架。 - -### 2.2 三层职责 - -| 层 | 输入 | 输出 | 不该做 | -|----|------|------|--------| -| **InputAdapter** | `BaseVector*` + row index (IR) | `IRRow`(in register) | 不感知 group memory | -| **GroupOps** | `IRRow` + `group ptr`(IR) | 写回 group / 产出新 `IRRow` | 不感知 vector encoding | -| **OutputAdapter** | `IRRow` + `BaseVector*` + row index | 写回 vector | 不感知 group memory | - -每一层都对其它两层透明 —— 通过 IRRow 的标准接口(见 §3)通信。 - -### 2.3 调用链对应关系 - -| 算子方法 | 三层调用链 | -|----------|-----------| -| `addRawInput` | `InputAdapter::read(rawVec, i)` → `GroupOps::accumulate(group, IRRow)` | -| `addIntermediateResults` | `InputAdapter::read(intVec, i)` → 
`GroupOps::merge(group, IRRow)` | -| `extractIntermediateResults` | `GroupOps::loadIntermediate(group)` → `OutputAdapter::write(intVec, i, IRRow)` | -| `extractResults` | `GroupOps::finalize(group)` → `OutputAdapter::write(finalVec, i, IRRow)` | -| `initGroup` | `GroupOps::init(group)` | - ---- - -## 3. 框架层 API(必须实现) - -新增文件:`velox/exec/jit/IRRow.h`、`velox/exec/jit/InputAdapter.h`、`velox/exec/jit/GroupOps.h`、`velox/exec/jit/OutputAdapter.h`(路径按 bolt 现有 jit 目录调整)。 - -### 3.1 `IRRow` —— 唯一传输格式 - -```cpp -class IRRow { - public: - // 类型构造:value_type 由 aggregate 决定 - static llvm::StructType* getType(llvm::IRBuilder<>& b, llvm::Type* value_type) { - return llvm::StructType::get(value_type, b.getInt1Ty()); - } - - // ---- 读 ---- - static llvm::Value* getValue(llvm::IRBuilder<>& b, llvm::Value* row) { - return b.CreateExtractValue(row, {0}); - } - static llvm::Value* getIsNull(llvm::IRBuilder<>& b, llvm::Value* row) { - return b.CreateExtractValue(row, {1}); - } - - // ---- 写 ---- - static llvm::Value* pack(llvm::IRBuilder<>& b, - llvm::Value* val, - llvm::Value* is_null) { - auto* ty = llvm::StructType::get(val->getType(), is_null->getType()); - auto* tmp = b.CreateInsertValue(llvm::UndefValue::get(ty), val, {0}); - return b.CreateInsertValue(tmp, is_null, {1}); - } - - static llvm::Value* withValue(llvm::IRBuilder<>& b, - llvm::Value* row, - llvm::Value* val) { - return b.CreateInsertValue(row, val, {0}); - } - static llvm::Value* withIsNull(llvm::IRBuilder<>& b, - llvm::Value* row, - llvm::Value* is_null) { - return b.CreateInsertValue(row, is_null, {1}); - } - - // ---- 复合 value 的二级访问:仅在 GroupOps 内部使用 ---- - static llvm::Value* getValueField(llvm::IRBuilder<>& b, - llvm::Value* row, - unsigned idx) { - return b.CreateExtractValue(row, {0, idx}); // 注意:嵌套 extractvalue 索引路径(作用于 SSA 聚合值,并非 GEP,不涉及指针运算) - } -}; -``` - -**强约束**:除了 `IRRow` 这套 helper 之外,**任何代码不得**直接对 IRRow struct 做 `extractvalue` / `insertvalue` —— 一旦发现就是抽象泄漏。 - -### 3.2 `InputAdapter` —— 规范化输入描述 → IRRow - -```cpp -class InputAdapter { - public: - 
virtual ~InputAdapter() = default; - - // 在 codegen 阶段调用,返回 IR 类型(必须等于对应 aggregate 的 IRRow_t) - virtual llvm::StructType* irRowType(llvm::IRBuilder<>& b) const = 0; - - // 在 IRBuilder 当前位置生成读取代码:从 vector + index 读出一个 IRRow - virtual llvm::Value* read(llvm::IRBuilder<>& b, - llvm::Value* vector_ctx, - llvm::Value* row_idx) const = 0; -}; -``` - -**关键修正**:这里的 `InputAdapter` **不能**按原始 vector encoding(flat / constant / dictionary)拆成不同 JIT 实现。 - -原因是:同一个 compiled chunk 会反复运行在不同 batch 上,而 batch 的原始 encoding 可以变化。若按原始 -encoding 生成不同 IR,则 JIT module cache key 会被 batch 形态污染,代码无法收敛,甚至会退化成“按批次特化并反复编译”。 - -因此,正确边界应当是: - -```text -原始 Vector(flat/constant/dictionary/...) - │ - ▼ -GroupingSet / DecodedVector 先做批次级规范化 - │ - ▼ -Canonical decoded descriptor - { values, indices, nulls, decodedVector, rowField*... } - │ - ▼ -InputAdapter 只针对“规范化后的运行时描述”生成 IR -``` - -也就是说: - -- flat / constant / dictionary 的差异,应该在 **JIT 之前** 被 `DecodedVector` + runtime descriptor 吸收; -- JIT 内的 `InputAdapter` 面向的是**稳定 ABI**,而不是每个 batch 的原始 encoding; -- 这样生成出来的 IR 才能在不同 batch 上复用并保持收敛。 - -**实现一览**(最少需要这些 adapter,它们对应“规范化后的输入形态”,而不是原始 encoding): - -| Adapter | 处理的规范化形态 | 关键 IR 行为 | -|---------|------------------|--------------| -| `DecodedScalarInputAdapter` | 标量输入:`values + indices + nulls` | `index = indices[row]`;`gep + load` 数据;`bit test` top-level nulls;pack 成 IRRow | -| `DecodedRowInputAdapter` | ROW intermediate:`rowField* + decodedVector(fallback)` | 优先走 field raw pointers/nulls;必要时回退 row-field helper;在 IR 里构造嵌套 IRRow | -| `CountStarInputAdapter` | 无实参输入 | 直接产出固定非空 IRRow / 或由 GroupOps 特判 | - -> 每个 adapter **只负责自己**对应的“规范化输入 contract”到 IRRow 的转换,不涉及任何聚合语义。 -> -> 特别注意:`DecodedScalarInputAdapter` 生成的 IR 在 flat / constant / dictionary batch 上应完全相同;不同 batch -> 只通过 `indices/nulls/values` 的运行时内容体现差异,而不改变 IR 形状。 - -### 3.3 `GroupOps` —— IRRow ↔ Group - -**关键修正**:`GroupOps` 在 bolt 当前实现里,**不应该**被设计成“拥有 group layout / group size / group align”的抽象。 - -当前事实是: - -- group memory 由 `RowContainer + AggregateInfo + 
accumulator layout` 共同决定; -- JIT 侧真正拿到的是 `group ptr + HashAggrJitSlot`; -- 访问状态依赖 `slot.offset / slot.nullByte / slot.nullMask`,以及像 `JitAvgState` / `JitDecimal*State` - 这样的现有 state struct offset; -- 当前 `HashAggrJitOps` 也是围绕这个 contract 工作,而不是自己管理 group allocation。 - -因此,更贴近 bolt 现状的 `GroupOps` 应该是:**“在既有 slot/layout 之上生成 group state 读写 IR 的薄层 policy”**,而不是一个重新定义 group 存储协议的 owner。 - -```cpp -class GroupOps { - public: - virtual ~GroupOps() = default; - - // 该聚合的 intermediate value type(不含 is_null,框架自动包一层) - virtual llvm::Type* intermediateValueType(llvm::IRBuilder<>& b) const = 0; - virtual llvm::Type* finalValueType(llvm::IRBuilder<>& b) const = 0; - - // ---- codegen hooks ---- - // slot 提供当前 aggregate 在 group row 中的 offset/null-bit 等元数据。 - virtual void init(HashAggrJitCodegen& codegen, - llvm::Value* group, - const HashAggrJitSlot& slot) const = 0; - - // 用 raw input 的 IRRow 累加进 group(对应当前 addRawInput)。 - virtual void accumulate(HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* input_irrow, - const HashAggrJitSlot& slot, - llvm::BasicBlock* nextBlock) const = 0; - - // 用 partial / intermediate 的 IRRow 合并进 group(对应当前 addIntermediateResults)。 - virtual void merge(HashAggrJitCodegen& codegen, - llvm::Value* group, - llvm::Value* intermediate_irrow, - const HashAggrJitSlot& slot, - llvm::BasicBlock* nextBlock) const = 0; - - // 从 group 读出 intermediate IRRow(extractIntermediateResults) - virtual llvm::Value* loadIntermediate(HashAggrJitCodegen& codegen, - llvm::Value* group, - const HashAggrJitSlot& slot) const = 0; - - // 从 group 读出 final IRRow(extractResults) - virtual llvm::Value* finalize(HashAggrJitCodegen& codegen, - llvm::Value* group, - const HashAggrJitSlot& slot) const = 0; - - virtual bool canExtract(const HashAggrJitSlot& slot, - bool partialOutput) const = 0; -}; -``` - -**关键约束**: -1. `loadIntermediate` 返回的 IRRow 类型 = `IRRow::getType(intermediateValueType())`;`finalize` 返回 = `IRRow::getType(finalValueType())`。 -2. 
group 内 memory layout **不是 `GroupOps` 自己分配/注册**的;它依旧来源于现有 accumulator/state layout,`GroupOps` 只是通过 `slot + state field offset` 去访问。 -3. 第一阶段 `GroupOps` 可以是当前 `HashAggrJitOps` 的**薄 facade**:先把“状态读写逻辑”从 aggregate ops 中理顺,不要求第一步就重写整个 JIT chunk 生成框架。 -4. **null 处理统一在这一层完成**:`accumulate / merge` 必须显式处理 `IRRow::getIsNull(input)`,框架不再依赖任何外部状态。 -5. `nextBlock` 仍作为参数保留,是为了兼容当前 `genAddDenseIR(...)` 的控制流拼装方式;不要为了追求接口漂亮而强行重写外层 loop/branch 骨架。 - -### 3.3.1 与当前 `HashAggrJitOps` 的映射 - -为了降低迁移风险,建议第一阶段直接保持与现有 `HashAggrJitOps` 一一对应: - -| 当前接口 | 收敛后的职责 | -|----------|--------------| -| `initGroup` | `GroupOps::init` | -| `addRawInput` | `InputAdapter::read(raw)` → `GroupOps::accumulate` | -| `addIntermediateResults` | `InputAdapter::read(intermediate)` → `GroupOps::merge` | -| `canExtract` | `GroupOps::canExtract` | -| `extract` | `GroupOps::loadIntermediate/finalize` → `OutputAdapter::write` | - -也就是说,**第一步不是删掉 `HashAggrJitOps`,而是让它退化为一个桥接层**: - -- 对外仍维持当前 JIT chunk 代码生成入口; -- 对内逐步把输入读取 / group 状态访问 / 输出写回转发到新三层; -- 等所有 aggregate 都迁完后,再决定是否彻底折叠旧表结构。 - -### 3.4 `OutputAdapter` —— IRRow → Vector - -```cpp -class OutputAdapter { - public: - virtual ~OutputAdapter() = default; - - virtual llvm::StructType* irRowType(llvm::IRBuilder<>& b) const = 0; - - // 把 IRRow 写入 vector[row_idx] - virtual void write(llvm::IRBuilder<>& b, - llvm::Value* vector_ctx, - llvm::Value* row_idx, - llvm::Value* irrow) const = 0; -}; -``` - -输出端通常只需要 `FlatOutputAdapter` 和 `RowOutputAdapter`(写复合 intermediate)。 - ---- - -## 4. 
各聚合落地示例 - -### 4.1 `sum` / `sum`(最简单) - -```cpp -class SumGroupOps : public GroupOps { - llvm::Type* intermediateValueType(IRBuilder& b) const override { return b.getInt64Ty(); } - llvm::Type* finalValueType(IRBuilder& b) const override { return b.getInt64Ty(); } - - void init(HashAggrJitCodegen& codegen, - Value* group, - const HashAggrJitSlot& slot) const override { - codegen.setAccumulatorNull(group, slot); - codegen.storeValue(group, codegen.builder().getInt64Ty(), slot.offset, - codegen.builder().getInt64(0)); - } - - void accumulate(HashAggrJitCodegen& codegen, - Value* group, - Value* in, - const HashAggrJitSlot& slot, - BasicBlock*) const override { - auto& b = codegen.builder(); - auto* in_null = IRRow::getIsNull(b, in); - auto* in_val = IRRow::getValue(b, in); - // if (!in_null) { sum += in_val; is_null = false; } - BasicBlock *if_t = ..., *cont = ...; - b.CreateCondBr(b.CreateNot(in_null), if_t, cont); - b.SetInsertPoint(if_t); - auto* old = codegen.loadValue(group, b.getInt64Ty(), slot.offset); - codegen.storeValue(group, b.getInt64Ty(), slot.offset, b.CreateAdd(old, in_val)); - codegen.clearAccumulatorNull(group, slot); - b.CreateBr(cont); - b.SetInsertPoint(cont); - } - - void merge(HashAggrJitCodegen& codegen, - Value* group, - Value* in, - const HashAggrJitSlot& slot, - BasicBlock* next) const override { - accumulate(codegen, group, in, slot, next); - } - - Value* loadIntermediate(HashAggrJitCodegen& codegen, - Value* group, - const HashAggrJitSlot& slot) const override { - auto& b = codegen.builder(); - return IRRow::pack( - b, - codegen.loadValue(group, b.getInt64Ty(), slot.offset), - codegen.isAccumulatorNull(group, slot)); - } - Value* finalize(HashAggrJitCodegen& codegen, - Value* group, - const HashAggrJitSlot& slot) const override { - return loadIntermediate(codegen, group, slot); - } - - bool canExtract(const HashAggrJitSlot&, bool) const override { - return true; - } -}; -``` - -### 4.2 `avg`(复合 value,**采用嵌套**) - -```cpp -class AvgGroupOps 
: public GroupOps { - // intermediate value = { double sum, i64 count };is_null 由框架包外层 - llvm::Type* intermediateValueType(IRBuilder& b) const override { - return llvm::StructType::get(b.getDoubleTy(), b.getInt64Ty()); - } - llvm::Type* finalValueType(IRBuilder& b) const override { return b.getDoubleTy(); } - - // 注意:这里不是重新定义 group layout,而是复用现有 accumulator/state layout。 - // 当前 bolt 中 avg 仍应与 JitAvgState / slot.offset / kAvgCountOffset 保持一致, - // 避免第一阶段重构把 state ABI 一起打散。 - - void init(HashAggrJitCodegen& codegen, - Value* group, - const HashAggrJitSlot& slot) const override { - ... // 对齐当前 compileAvgInitGroup:setAccumulatorNull + sum/count 初始化 - } - - // raw input: IRRow_t = { double, i1 } - void accumulate(HashAggrJitCodegen& codegen, - Value* group, - Value* in, - const HashAggrJitSlot& slot, - BasicBlock*) const override { - auto& b = codegen.builder(); - auto* in_null = IRRow::getIsNull(b, in); - auto* in_val = IRRow::getValue(b, in); - // if (!in_null) { sum += val; count += 1; } - ... - } - - // intermediate: IRRow_t = { {double, i64}, i1 } - void merge(HashAggrJitCodegen& codegen, - Value* group, - Value* in, - const HashAggrJitSlot& slot, - BasicBlock* nextBlock) const override { - auto& b = codegen.builder(); - auto* in_null = IRRow::getIsNull(b, in); - auto* part_sum = IRRow::getValueField(b, in, 0); // double - auto* part_count = IRRow::getValueField(b, in, 1); // i64 - // if (!in_null) { sum += part_sum; count += part_count; } - ... 
- } - - Value* loadIntermediate(HashAggrJitCodegen& codegen, - Value* group, - const HashAggrJitSlot& slot) const override { - auto& b = codegen.builder(); - auto* count = loadCount(b, group); - auto* sum = loadSum(b, group); - auto* is_null = b.CreateICmpEQ(count, b.getInt64(0)); - // 构造嵌套 struct { double, i64 } - auto* inner_ty = intermediateValueType(b); - auto* inner = b.CreateInsertValue(UndefValue::get(inner_ty), sum, {0}); - inner = b.CreateInsertValue(inner, count, {1}); - return IRRow::pack(b, inner, is_null); - } - - Value* finalize(HashAggrJitCodegen& codegen, - Value* group, - const HashAggrJitSlot& slot) const override { - auto& b = codegen.builder(); - auto* count = loadCount(b, group); - auto* sum = loadSum(b, group); - auto* is_null = b.CreateICmpEQ(count, b.getInt64(0)); - auto* avg = b.CreateFDiv(sum, b.CreateSIToFP(count, b.getDoubleTy())); - return IRRow::pack(b, avg, is_null); - } - - bool canExtract(const HashAggrJitSlot& slot, bool partialOutput) const override { - return ...; // 第一阶段直接镜像当前 canCompileAvgExtract 语义 - } -}; -``` - -`InputAdapter` 端只需要: -- raw 输入 → `DecodedScalarInputAdapter` (IRRow = `{double, i1}`) -- intermediate 输入 → `DecodedRowInputAdapter`,读取规范化后的 `rowField* / decodedVector(fallback)`,自动构造嵌套 IRRow - -### 4.3 `count` - -```cpp -class CountGroupOps : public GroupOps { - llvm::Type* intermediateValueType(IRBuilder& b) const override { return b.getInt64Ty(); } - llvm::Type* finalValueType(IRBuilder& b) const override { return b.getInt64Ty(); } - // count 永远不是 null,is_null 字段恒为 false(LLVM 会优化掉) - // ... 
-}; -``` - -> 实现上应继续贴合当前 bolt:`count(*)` 的 raw-input 路径本质是 `+1`,而非真的去读取一个输入列; -> merge 路径则读取 intermediate bigint。不要为了统一接口而把 `count(*)` 硬塞进一个虚构输入列模型里。 - -### 4.4 `min / max` - -类似 sum,把 `Add` 换成 `select(cmp, old, new)` 即可。 - -### 4.5 `stddev`(前瞻验证,体现可扩展性) - -```cpp -llvm::Type* intermediateValueType(IRBuilder& b) const override { - return llvm::StructType::get(b.getInt64Ty(), // count - b.getDoubleTy(), // mean - b.getDoubleTy()); // M2 -} -// IRRow_t = { {i64, double, double}, i1 },框架完全无需改动 -``` - ---- - -## 5. 与现有代码的对接 - -### 5.1 删除项 - -| 文件 / 符号 | 处置 | -|-------------|------| -| `struct HashAggrJitDecodedInput` | 第一阶段**不删除**;先把它收敛为 canonical decoded descriptor,供 `InputAdapter` 消费;所有 aggregate 迁移完成后再决定是否改名/瘦身 | -| `struct HashAggrJitOutput` | 第一阶段**不删除**;先把它收敛为 canonical output descriptor,供 `OutputAdapter` 消费;待 extract 全迁完后再决定是否改名/瘦身 | -| 各 `XxxAggregate::codegenAddDense` 中关于 vector decoding 的代码 | 迁到 `InputAdapter` | -| 各 `XxxAggregate::codegenAddDense` 中关于 group rw 的代码 | 迁到 `GroupOps` | -| `Aggregate::numNulls_` 的更新依赖(JIT path) | 删除依赖 | -| 任何 `extractvalue` / `insertvalue` 直接对 IRRow 做的代码 | 替换成 `IRRow::*` helper | - -> 额外说明:第一阶段的目标是**理顺职责边界**,不是立即改变 group row 的底层存储协议; -> `slot.offset/nullByte/nullMask` 与现有 state struct offset 仍然是合法的迁移期依赖。 - -### 5.2 新增项 - -``` -velox/exec/jit/ -├── IRRow.h -├── InputAdapter.h -├── input_adapters/ -│ ├── DecodedScalarInputAdapter.h -│ ├── DecodedRowInputAdapter.h -│ └── CountStarInputAdapter.h -├── GroupOps.h -├── group_ops/ -│ ├── SumGroupOps.{h,cpp} -│ ├── CountGroupOps.{h,cpp} -│ ├── AvgGroupOps.{h,cpp} -│ ├── MinMaxGroupOps.{h,cpp} -│ └── ... -├── OutputAdapter.h -└── output_adapters/ - ├── FlatOutputAdapter.h - └── RowOutputAdapter.h -``` - -### 5.3 单测要求 - -新增测试 `HashAggrJitFrameworkTest.cpp`,必须覆盖: - -1. `IRRow::pack/getValue/getIsNull` 在简单类型与嵌套类型上 round-trip。 -2. `DecodedScalarInputAdapter` 在 flat/constant/dictionary 三种 batch encoding 上生成**同一形状 IR**,并通过不同的 `indices/nulls/values` runtime 内容得到正确结果。 -3. 
`DecodedRowInputAdapter` 在 row-field raw fast path 与 helper fallback 两条路径上结果一致。 -4. 每个 `GroupOps`:init → accumulate(若干 raw + 若干 null) → loadIntermediate → merge(到另一 group) → finalize 与 reference 实现一致。 -5. **专项 null 测试**:所有输入都是 null 时,`finalize` 必须返回 `is_null = true`。 -6. avg intermediate 必须验证 IRRow 的 LLVM type 字面就是 `{ {double, i64}, i1 }`(而非平铺)。 - ---- - -## 6. 嵌套 vs 平铺:决策记录(avg 等复合 value) - -| 维度 | 嵌套 `{{double,i64}, i1}` ✅ | 平铺 `{double, i64, i1}` ❌ | -|------|----------------------------|---------------------------| -| `IRRow = {T, i1}` invariant | 保持 | 破坏 | -| `IRRow::getValue / getIsNull` 是否通用 | 是(`{0}` / `{1}`) | 否,avg 要 special case | -| 框架对 T 的内部结构 | 不感知 | 必须知道 field 数 | -| 新增复合 value 聚合(stddev/HLL/...) | 0 改动 | 框架每次都要扩展 | -| Memory layout | 24B(offset 0/8/16) | 24B(offset 0/8/16),完全相同 | -| LLVM lowering 性能 | 经 SROA/InstCombine 后与平铺一致 | 与嵌套一致 | -| IR 可读性 | 略冗长(多一层 `{0,k}`) | 更短 | - -**结论**:嵌套方案在抽象一致性、可扩展性上完胜,且无任何性能代价。**全部聚合统一采用嵌套布局。** - ---- - -## 7. 落地步骤(建议 PR 顺序) - -### 7.0 首个最小可实施 patch(本轮直接落地) - -为了避免第一步就同时改动 chunk ABI、aggregate ops table、descriptor 字段和 benchmark 口径,首个 patch 只做**最小且可验证**的框架落点: - -#### Patch-1 范围 - -1. 新增 `IRRow` helper(建议先放在现有 `bolt/jit/aggregation/` 目录下,而不是一开始新建整套 framework 目录): - - `getType` - - `pack` - - `getValue` - - `getIsNull` - - `withValue` - - `withIsNull` - - `getValueField` -2. 新增对应单测,至少覆盖: - - 标量 value 的 round-trip; - - 嵌套 value(如 `{{double, i64}, i1}`)的 round-trip; - - `withValue / withIsNull` 的覆盖更新语义; -3. **不修改**当前 `HashAggrJitChunk` ABI; -4. **不修改** `HashAggrJitDecodedInput / HashAggrJitOutput` 结构; -5. 
**不迁移**任何 aggregate 到新三层,只把 `IRRow` 作为第一块可复用基建先落进去。 - -#### Patch-1 预期收益 - -- 为后续 `GroupOps::loadIntermediate/finalize` 提供统一返回协议; -- 为 avg / decimal avg 这类复合 value 的嵌套 IRRow 建立稳定 helper; -- 先把最容易验证、最不影响性能的部分单独落地,降低后续 patch 风险。 - -#### Patch-1 验证方式 - -- 编译 `bolt_thrustjit`; -- 若当前配置包含测试,则额外编译并运行 `bolt_thrustjit_test`; -- 该 patch 不应改变任何现有 hash aggr JIT 生成 IR 的行为与性能。 - -#### Patch-2 范围(紧接 Patch-1) - -第二个最小 patch 继续保持“不改 ABI、不改 chunk 骨架”的原则,只做**标量输入读取的一层内部收口**: - -1. 新增一个极薄的 `DecodedScalarInputAdapter` helper; -2. 第一阶段只提供 `readKnownNotNull(...)`: - - 适用于当前外层控制流已经完成 top-level null 过滤后的路径; - - 直接把 `loadDecodedValue(...)` 的结果打包成 `IRRow`,并把 `is_null` 固定为 `false`; -3. 仅选择 `sum` 作为第一个接入对象,把 `SumOps.cpp` 中对标量输入的直接读取改成通过该 helper; -4. 不改 `HashAggrJitChunk`、`genAddDenseIR(...)`、`HashAggrJitDecodedInput` ABI; -5. `readNullable(...)`、`DecodedRowInputAdapter`、`GroupOps facade` 留到后续 patch。 - -#### Patch-2 预期收益 - -- 验证“InputAdapter 是内部 codegen helper,而不是按 batch encoding 分裂 ABI”的设计方向; -- 让后续 `sum/minmax/avg raw-input` 迁移时有统一入口; -- 继续保证热路径不回退:外层 null 分支与现有 tight loop 骨架保持不变。 - -#### Patch-2 验证方式 - -- 编译 `bolt_thrustjit_test`; -- 运行新增 `IRRow` / `DecodedScalarInputAdapter` 相关测试; -- 确认 `SumOps.cpp` 只是把“直接 load decoded scalar”替换成 helper,不改变现有 null 过滤与算子语义。 - -#### Patch-3 范围(延续 Patch-2) - -第三个最小 patch 继续沿用同一迁移策略,把 `DecodedScalarInputAdapter` 的使用从 `sum` 扩展到 `min/max`: - -1. 不新增 ABI; -2. 不修改 `DecodedScalarInputAdapter` 接口; -3. 仅把 `MinMaxOps.cpp` 中 raw-input 标量读取切换为 `DecodedScalarInputAdapter::readKnownNotNull(...)`; -4. 保持当前外层 null 过滤、NaN 处理和比较逻辑不变; -5. 
不触碰 merge row-field 路径,不引入 nullable adapter。 - -#### Patch-3 预期收益 - -- 让 `sum/min/max` 三个最基础的标量 raw-input 聚合统一走同一条内部读取入口; -- 进一步验证“InputAdapter 是内部 codegen helper,而不是新的运行时 ABI”; -- 为后续批量迁移其它标量聚合打样。 - -#### Patch-3 验证方式 - -- 编译 `bolt_thrustjit_test`; -- 运行现有 `IRRow` / `DecodedScalarInputAdapter` 相关测试,确认基建未破坏; -- 确认 `MinMaxOps.cpp` 只替换输入读取入口,不改变 min/max 的比较、NaN 语义和 extract 逻辑。 - -#### Patch-4 范围(先补 nullable contract,不立刻接算子) - -第四个最小 patch 只补 `DecodedScalarInputAdapter` 的 nullable contract,本轮**先落 helper 与单测,不接任何 aggregate**: - -1. 新增 `DecodedScalarInputAdapter::readNullable(...)`; -2. helper 负责: - - 读取 `nulls` 指针; - - 在 `nulls == nullptr` 时走非空快速路径; - - 在 `nulls != nullptr` 时按 row bit 判断是否为 null; - - 返回 `IRRow{value, is_null}`; -3. null 行上不要求读取真实 payload,允许写入 typed zero 作为占位值; -4. 本 patch **不修改** `sum/min/max/count/avg` 等 aggregate; -5. 通过一个可执行的 JIT 单测验证 nullable 语义,而不是只做类型级验证。 - -#### Patch-4 预期收益 - -- 正式建立 `DecodedScalarInputAdapter` 的 nullable 语义 contract; -- 为后续把外层 null 分支逐步内聚到 InputAdapter 提供基础; -- 先用单测把“null 行不读取真实 payload、仅传递 is_null”这件事定下来,避免后续改算子时语义摇摆。 - -#### Patch-4 验证方式 - -- 编译 `bolt_thrustjit_test`; -- 运行新增 `readNullable` JIT 语义测试: - - `nulls == nullptr` 时返回真实 value; - - `nulls` bit 置位时返回 `is_null = true` 对应结果; - - 非 null 行仍返回真实 value; -- 继续跑已有 `IRRow` / `DecodedScalarInputAdapter` 基础测试,确认旧 contract 不回退。 - -#### Patch-5 范围(让 sum 消费 nullable IRRow,但先保留外层 null branch) - -第五个最小 patch 开始让真实 aggregate 消费 `readNullable(...)` 产出的 `IRRow`,但仍然坚持“不一次性收掉外层控制流”: - -1. `SumOps.cpp` 的 add/merge 路径统一先读取 `inputRow = DecodedScalarInputAdapter::readNullable(...)`; -2. `sum` 内部通过 `IRRow::getIsNull(inputRow)` 决定是否跳过累加; -3. 现有 `genAddDenseIR(...)` 的 top-level null 过滤分支**保留不动**; -4. 这意味着本 patch 的行为应与当前逻辑保持一致,只是把 `sum` 的内部消费协议收口到 nullable IRRow; -5. 
本 patch 不要求立即让外层 null 分支失效或删除。 - -#### Patch-5 预期收益 - -- 第一次验证“真实 aggregate 可以消费 nullable IRRow contract”; -- 为后续是否收掉外层 null 分支提供对照基线; -- 把 `sum` 变成第一个同时兼容 known-not-null 与 nullable 读取 contract 的算子样板。 - -#### Patch-5 验证方式 - -- 编译 `bolt_thrustjit_test`; -- 运行一个最小 `sum-like` JIT 语义测试:null 行返回旧 accumulator,不为 null 时返回 `old + value`; -- 继续运行已有 `IRRow` / `DecodedScalarInputAdapter` 基础测试,确认 helper contract 不回退; -- 确认 `SumOps.cpp` 仍未修改外层 null 过滤框架,仅改变输入消费方式。 - -#### Patch-6 范围(让 min/max 同样消费 nullable IRRow) - -第六个最小 patch 把 Patch-5 在 `sum` 上验证过的模式复制到 `min/max`: - -1. `MinMaxOps.cpp` 的 update 路径改为先读取 `inputRow = DecodedScalarInputAdapter::readNullable(...)`; -2. 通过 `IRRow::getIsNull(inputRow)` 显式跳过 null 行的比较与写回; -3. 非 null 行仍执行原有 min/max 比较、NaN 处理与 accumulator null 清除逻辑; -4. 现有 `genAddDenseIR(...)` 的 top-level null 过滤分支保留不动; -5. 行为与当前实现保持一致,仅把输入消费协议收口到 nullable IRRow。 - -#### Patch-6 预期收益 - -- 让 `sum/min/max` 统一以 nullable IRRow contract 消费输入; -- 进一步验证“先收口消费协议、暂不删外层 null branch”这一渐进模式在带比较/NaN 语义的算子上同样成立; -- 为后续真正收掉外层 null 分支留出一致的算子基线。 - -#### Patch-6 验证方式 - -- 编译 `bolt_thrustjit_test`; -- 运行已有 `IRRow` / `DecodedScalarInputAdapter` / `sum-like` 测试,确认 contract 不回退; -- 确认 `MinMaxOps.cpp` 仅替换输入消费方式,不改变比较、NaN 与 extract 语义,也未触碰外层 null 框架。 - -#### Patch-7 范围(引入最小 FlatOutputAdapter 并让 sum extract 接入) - -前面几个 patch 都在 input 端收口,本 patch 开始对称地在 output 端引入第一块 helper: - -1. 新增最小 `FlatOutputAdapter`(同样是 codegen-time helper,不引入任何运行时 ABI); -2. 只提供 `writeFromIRRow(codegen, output, row, slot, irRow)`: - - 从 `IRRow` 取 value 与 i1 `is_null`; - - 把 `is_null` zext 到 i8; - - 复用现有 `emitFlatValue(...)` 写回 flat 输出; -3. 让 `SumOps.cpp` 的 extract 先用 `IRRow::pack(value, is_null)` 组装,再通过 `FlatOutputAdapter::writeFromIRRow(...)` 写回; -4. 不修改 `HashAggrJitOutput` 结构与 `genExtractIR(...)` 骨架; -5. 
仅 `sum` 接入,`min/max/count/avg/decimal` 的 extract 暂不动。 - -#### Patch-7 预期收益 - -- 让 output 端也有一个与 `IRRow` 对齐的统一写回入口; -- 验证“OutputAdapter 也是内部 codegen helper,而非新 ABI”这一设计方向; -- 为后续把更多 extract 收口到 `FlatOutputAdapter` / `RowOutputAdapter` 打样。 - -#### Patch-7 验证方式 - -- 编译 `bolt_thrustjit_test`; -- 编译并运行 `bolt_aggregates_test` 的 `SumTest` 相关用例,确认 sum extract 行为未回归; -- 确认 `SumOps.cpp` 的 extract 仅改写写回入口,flat 输出语义与 null 位写入保持一致。 - -#### Patch-8 范围(min/max 与 count 的 extract 也接入 FlatOutputAdapter) - -继续把 output 端收口扩展到其余标量聚合: - -1. `MinMaxOps.cpp` 的 extract 改为:`IRRow::pack(value, isAccumulatorNull)` → `FlatOutputAdapter::writeFromIRRow(...)`; -2. `CountOps.cpp` 的 extract 改为:`IRRow::pack(value, false)` → `FlatOutputAdapter::writeFromIRRow(...)`(count 永不为 null); -3. 不修改 `HashAggrJitOutput` 结构与 `genExtractIR(...)` 骨架; -4. 行为与当前实现保持一致,只把写回入口统一到 `FlatOutputAdapter`。 - -#### Patch-8 预期收益 - -- 让 `sum/min/max/count` 四个标量聚合的 extract 全部走统一 output 入口; -- 进一步压实“OutputAdapter 是内部 codegen helper”的方向; -- 为后续 partial avg / decimal 等复杂 extract 的 RowOutputAdapter 收口铺路。 - -#### Patch-8 验证方式 - -- 编译 `bolt_thrustjit_test`、`bolt_aggregates_test`; -- 运行 `CountAggregationTest` 与 `SumTest` 相关用例,确认 extract 未引入新回归; -- 已知 `MinMaxTest` 三个 JIT 对照用例失败,本轮先忽略,仅确认未新增其它失败。 - -#### Patch-9 范围(avg 的 final extract 接入 FlatOutputAdapter) - -avg 的 final extract 本质就是把 `avg = sum / count` 写回一个 flat double,与 sum/min/max 同形态,因此先收口 final 分支: - -1. `AvgOps.cpp` 的 `compileAvgExtract` 在 `partialOutput == false` 分支改为:`IRRow::pack(avg, is_null)` → `FlatOutputAdapter::writeFromIRRow(...)`; -2. partial avg 的 ROW 输出(`emitPartialAvgResult`)暂不动,留待后续 `RowOutputAdapter`; -3. 不修改 `HashAggrJitOutput` 结构与 `genExtractIR(...)` 骨架; -4. 
final avg 的 `count == 0 -> null`、divide 语义保持不变。 - -#### Patch-9 预期收益 - -- 让 sum/min/max/count/avg(final) 的 flat extract 全部走统一 output 入口; -- 把 partial(ROW)与 final(flat)两类 output 路径的边界显式化,为 `RowOutputAdapter` 铺路。 - -#### Patch-9 验证方式 - -- 编译 `bolt_thrustjit_test`、`bolt_aggregates_test`; -- 运行 `AverageAggregationTest` 相关用例,确认 final avg extract 行为未回归; -- 确认 `AvgOps.cpp` 仅改写 final 分支写回入口,partial ROW 输出与 divide/null 语义不变。 - -为了控制每个 MR 的 diff 体积,推荐拆 4 个 MR 提交: - -| # | MR 标题 | 范围 | 依赖 | -|---|---------|------|------| -| 1 | `[jit] Introduce IRRow + canonical Input/Output descriptors + GroupOps facade` | 在现有 `HashAggrJitOps` 外围引入三层抽象与单测,不改 chunk ABI | 无 | -| 2 | `[jit] Migrate sum/count/min/max onto GroupOps + Adapter internals` | 先迁简单标量聚合,外部入口保持兼容 | #1 | -| 3 | `[jit] Migrate avg with nested intermediate IRRow` | avg 落地嵌套方案,保留现有 state layout 与 extract 语义 | #2 | -| 4 | `[jit] Migrate decimal sum/avg and optionally shrink legacy tables` | decimal 收口,并视情况瘦身旧 descriptor / ops table | #3 | - -每个 MR 都要: -- 跑通现有 hash aggr e2e 测试集(重点覆盖含 null 输入的 case)。 -- 跑 micro benchmark 对比,必须 ≥ 当前 `9a65fd2` 性能。 -- LLVM IR dump(`-dump_ir`)肉眼检查 SROA 后是否消除了 alloca。 - ---- - -## 8. 验收标准(Definition of Done) - -- [ ] `HashAggrJitDecodedInput / HashAggrJitOutput` 至少已收敛为 canonical descriptor;若仍保留,也不再继续按新聚合需求横向扩字段。 -- [ ] 任何 IRRow 字段访问只能经过 `IRRow::*` helper(grep `extractvalue.*IRRow|insertvalue.*IRRow` 应为零)。 -- [ ] JIT path 不再读写 `Aggregate::numNulls_`。 -- [ ] avg intermediate 的 LLVM type 等于 `{{double, i64}, i1}`(单测断言)。 -- [ ] 新增 stddev 或任意复合 value 聚合时,**不需要修改 InputAdapter/OutputAdapter/IRRow 等框架边界定义**;最多新增对应 `GroupOps` / state helper。 -- [ ] e2e 性能不回退,TPC-H Q1(avg 重灾区)持平或提升。 - ---- - -## 9. 备注 - -- 本方案与上游 Velox 无 conflict —— Velox 没有 hash aggr JIT,bolt 这块是独立分叉。 -- 如果未来引入 PartialFinal 优化、ROW vector 嵌套加深,IRRow 接口无需改动。 -- 复合 value 聚合超过 3 层嵌套(极少见)时,建议在 `IRRow` 上提供 path-based getter(`getValueByPath({0,1,2})`),但不在本次重构范围内。 - ---- - -## 10. 
InputAdapter 虚接口重构设计(Approach-1 落地稿) - -这一章收敛“最终想要的效果”——**真正建立 `InputAdapter -> GroupOps -> OutputAdapter` 三层正交骨架,并最终删除 `HashAggrJitDecodedInput`**,而不是继续在旧 descriptor 上横向打补丁。 - -### 10.1 目标边界 - -本轮 InputAdapter 重构必须同时满足: - -1. `InputAdapter` 提供**虚函数接口**; -2. adapter 在**构造时直接接受 vector 输入**,而不是接受 `HashAggrJitDecodedInput` 这类中间拼装物; -3. 第一层实现只分两类: - - `ScalarInputAdapter` - - `RowInputAdapter` -4. JIT IR **不能按 flat / constant / dictionary 三种 encoding 分叉生成**;输入 encoding 差异必须在 adapter 内部先被吸收、收敛; -5. 热路径性能不能回退:不能把“每行一次 helper / 每行一次虚调用”重新引回 add-dense loop; -6. 完成后可以删除 `HashAggrJitDecodedInput`,后续新增聚合也不允许再给它加字段。 - -### 10.2 分层职责 - -#### A. InputAdapter:只负责“把 vector 解释成 IRRow 输入契约” - -InputAdapter 的职责是: - -- 接受 batch 内真实的 `BaseVector` / `RowVector` 输入; -- 在 adapter 内部完成 decode / flatten / indices/nulls 收敛; -- 对 JIT 暴露**稳定、encoding 无关**的 runtime payload; -- 对 codegen 暴露“如何从该 payload 读出 `IRRow`”的统一接口。 - -它**不负责**: - -- 聚合 state layout; -- add / merge / extract 语义; -- 输出 vector 写回; -- 聚合专有字段语义(例如 avg 的 `sum/count`、decimal 的 `isEmpty`)。 - -#### B. GroupOps:只负责“消费/产生 IRRow” - -GroupOps 只看: - -- 输入:`IRRow` 或嵌套 `IRRow` -- 状态:`group + slot.offset` -- 输出:`IRRow` - -也就是说,`sum/min/max/count/avg/decimal` 的差异只体现在各自 ops/state helper 中; -**GroupOps 不拥有 InputAdapter / OutputAdapter 的 ABI,也不拥有 group layout 定义权**。 - -#### C. 
OutputAdapter:只负责“把 IRRow 写回结果 vector” - -OutputAdapter 只做两件事: - -- flat output:写一个标量 `IRRow`; -- row output:把 `IRRow` 的每个 child field 和 top-level null 写回。 - -它不应理解“这是 avg 的 2-field row”或“这是 decimal sum 的 3-field row”; -字段个数/字段类型来自 `IRRow` 的 payload type,本身不携带聚合语义。 - -### 10.3 运行时对象模型 - -最终运行时不再构造 `HashAggrJitDecodedInput`,而是构造 adapter 对象: - -```cpp -class InputAdapter { - public: - virtual ~InputAdapter() = default; - - // 供 batch 准备阶段调用;完成 decode / flatten / child adapter 建立。 - virtual void prepare() = 0; - - // 返回稳定的、可传入 JIT add_dense 的 runtime payload。 - virtual const void* runtime() const = 0; - - // 返回与该 adapter 对应的 codegen 节点。 - virtual const InputAdapterCodegen& codegen() const = 0; -}; - -class ScalarInputAdapter final : public InputAdapter { ... }; -class RowInputAdapter final : public InputAdapter { ... }; -``` - -这里的关键点是: - -- **虚函数只发生在 batch 准备阶段**; -- JIT 热循环不做 virtual dispatch; -- JIT add-dense 看到的仍然是 `char**` / `void**` 风格的 runtime payload 数组,只是 payload 的拥有者从旧 struct 变成了 adapter; -- 因此性能上仍保持“每行直接 load 指针/indices/nulls”的 fast path。 - -### 10.4 Runtime payload 形状 - -为了替换 `HashAggrJitDecodedInput`,需要把“adapter 对象”和“JIT 可直接 load 的 POD payload”解耦: - -```cpp -struct ScalarInputRuntime { - const void* values; - const int32_t* indices; - const uint64_t* nulls; -}; - -struct RowInputRuntime { - const uint64_t* nulls; - const ScalarInputRuntime* const* children; // scalar child runtimes - int32_t numChildren; -}; -``` - -约束如下: - -- `ScalarInputRuntime` 对应今天 canonical decoded descriptor 的标量子集; -- `RowInputRuntime` 不再内嵌 `rowField0Values/rowField1Values/...` 这种聚合专有字段; -- `RowInputRuntime` 也不再保留 `indices`:row 自身不再承载 dictionary/constant 等 wrapping,若 child 仍需索引映射,应下沉到 child 自己的 scalar runtime; -- 当前阶段 `RowInputRuntime.children[i]` 直接指向 `ScalarInputRuntime`;也就是说,本轮只支持 **row-of-scalars**,不引入递归 row child; -- 若某些 merge fast path 需要 child flat raw 指针,应该由 `ScalarInputAdapter` 自身保证其 runtime payload 已经是可直接 load 的 canonical scalar 形态,而不是再给顶层 row runtime 增加“field0/field1 
特例字段”。 - -### 10.5 Codegen 侧接口 - -codegen 层不再把“输入读取”硬编码成 `loadDecodedValue / loadDecodedRowField*` 这一组围绕 `HashAggrJitDecodedInput` 的 offset 访问,而是收敛成: - -```cpp -class InputAdapterCodegen { - public: - virtual ~InputAdapterCodegen() = default; - // 返回 IRRow 中 payload 的 LLVM type。 - virtual llvm::Type* llvmValueType(HashAggrJitCodegen& codegen) const = 0; - - virtual llvm::Value* readIRRow( - HashAggrJitCodegen& codegen, - llvm::Value* runtime, - llvm::Value* row, - const HashAggrJitSlot& slot) const = 0; -}; -``` - -第一阶段只需要两个实现: - -- `ScalarInputAdapterCodegen` - - 从 `ScalarInputRuntime` 读出 `IRRow` - - 覆盖今天 `DecodedScalarInputAdapter::readKnownNotNull/readNullable` 的职责 -- `RowInputAdapterCodegen` - - 从 `RowInputRuntime` 先读出 `children[i]` 对应的 `ScalarInputRuntime*` - - 再通过内部持有的 scalar child readers 读取 child 的 value/null - - 最后组装顶层 `IRRow` - - 覆盖今天 `loadDecodedRowField / loadDecodedRowFieldBool` 这组“按 field 特判”的路径 - -这样 IR 收敛点就从“旧 descriptor 的固定字段”变成“adapter runtime 的稳定形状”。 - -#### 10.5.0.1 `RowInputRuntime.children` 如何读到 child 的 values/nulls - -这是 runtime / codegen 分层里最关键的一点: - -- `RowInputRuntime` **不直接暴露** `child0Values/child0Nulls/...` 这类字段; -- `RowInputRuntime` 只保存 `children[i]` —— 即 `ScalarInputRuntime*`; -- “如何从这个 child runtime 读出 value/null” 由 `RowInputAdapterCodegen` 内部持有的 scalar child readers 决定。 - -推荐读取流程: - -```cpp -auto* childRuntime = rowCodegen.loadScalarChildRuntime(codegen, rowRuntime, i); -auto* childRow = rowCodegen.scalarChildAt(i).readIRRow(codegen, childRuntime, row, slot); -auto* childValue = IRRow::getValue(builder, childRow); -auto* childIsNull = IRRow::getIsNull(builder, childRow); -``` - -也就是说: - -1. row runtime 只负责提供 `children[i]` 指针; -2. 当前阶段 child 固定为 scalar,因此不需要 runtime tag,也不需要递归 row dispatch; -3. child 的 `values/nulls/indices` 在 `ScalarInputRuntime` 内部读取; -4. 
因而“读 children 的 values/nulls”不是 `RowInputRuntime` 的接口,而是 `RowInputAdapterCodegen` 调度其 scalar child readers 的结果。 - -这也解释了为什么 row runtime 本身不需要再带: - -- `indices` -- `rowField0Values` -- `rowField1Values` - -因为这些都属于 child 的读取策略,不属于 row root 的职责。 - -#### 10.5.0.2 `InputAdapter` / `InputAdapterCodegen` 应如何调整 - -为了让上面的 child-reading 成立,接口要从“单节点一次性读完”调整成“runtime root + codegen 节点协作”: - -##### 运行时对象层 - -`InputAdapter` 负责两件事: - -1. 构造并拥有 runtime payload; -2. 暴露与之匹配的 codegen 节点。 - -也就是说,`InputAdapter::runtime()` 和 `InputAdapter::codegen()` 必须成对出现。 - -##### codegen 层 - -`InputAdapterCodegen` 基类仍然只需要: - -1. `llvmValueType(...)`:告诉框架当前节点的 payload LLVM type; -2. `readIRRow(...)`:把当前 runtime 解释成 `IRRow`。 - -row-specific 的 child 访问辅助接口不必上提到基类;它们由 `RowInputAdapterCodegen` 自身私有持有即可。这样接口更贴近当前“row-of-scalars”的范围,避免过早泛化。 - -##### `RowInputAdapterCodegen` 的实现语义 - -`RowInputAdapterCodegen::readIRRow(...)` 应按下面语义生成 IR: - -1. 先检查 `rowRuntime.nulls`,得到 top-level row 是否为 null; -2. 若 top-level 为 null,返回 `IRRow{zero_payload, true}`; -3. 若 top-level 非 null,则对每个 child: - - `childRuntime = loadScalarChildRuntime(..., i)` - - `childRow = scalarChildAt(i).readIRRow(..., childRuntime, row, slot)` -4. 
组装 payload: - - 若当前业务只需要 child value(例如 avg merge 的 `sum/count` 都是非 null 标量),可把 `IRRow::getValue(childRow)` 填入 payload; - - 当前阶段只支持 scalar child;若未来要支持 nested row child,再重新把 row child codegen 抽象上提。 - -因此,**从 `children` 读 values/nulls 的正确模型,不是给 `RowInputRuntime` 增字段,而是让 `RowInputAdapterCodegen` 持有一组 scalar child readers,并逐个解释 `ScalarInputRuntime`。** - -#### 10.5.1 `genAddDenseIR` 的 LLVM function 接口如何设计 - -结论先说:**可以去掉 `HashAggrJitDecodedInput`,而且 `genAddDenseIR` 的 LLVM ABI 不需要大改;最好的做法是“保留 3 参函数形状,只替换第 3 个参数的语义”**。 - -推荐接口: - -```cpp -using HashAggrJitAddDenseFunc = - void (*)(char** groups, int32_t numRows, char** inputRuntimes); -``` - -对应 LLVM: - -```llvm -define void @jit_HashAggrAddDense( - i8** %groups, - i32 %num_rows, - i8** %input_runtimes) -``` - -也就是说: - -- 参数 1:`groups`,不变; -- 参数 2:`numRows`,不变; -- 参数 3:从今天的 `decodedInputs` 改成 **`inputRuntimes`**; -- `inputRuntimes[slotIndex]` 指向该 slot 对应 InputAdapter 持有的 root runtime payload; -- JIT 函数本身**不知道也不需要知道**这是 C++ 虚对象,只把它当成 adapter-owned POD runtime 根指针来读。 - -这样做的关键收益是: - -1. `HashAggrJitChunk`、ORC JIT function pointer、调用侧大框架都几乎不用改 ABI; -2. `GroupingSet` 只需把 `hashAggrJitDecodedPtrs_` 的元素从“指向 `HashAggrJitDecodedInput`”改成“指向 adapter runtime”; -3. 热循环仍然是 `slotIndex -> load runtime ptr -> 直接 load values/indices/nulls`,不会引入 per-row virtual dispatch。 - -#### 10.5.2 为什么不建议把 LLVM 接口改成 `InputAdapter**` - -不推荐这种形状: - -```cpp -void (*)(char** groups, int32_t numRows, InputAdapter** adapters) -``` - -原因: - -1. JIT 热路径若想通过 `InputAdapter**` 做 virtual call,会直接把虚调用引进每行循环; -2. LLVM 对 C++ vtable/object layout 没有必要也不应该感知; -3. 
我们真正需要的是“稳定可 load 的 runtime payload”,而不是对象本身。 - -所以正确分层应该是: - -- **C++ 对象层**:`InputAdapter` 虚接口,负责 batch 准备; -- **JIT ABI 层**:`i8** input_runtimes`,只传 POD payload 指针; -- **codegen 层**:由 slot 绑定的 adapter codegen helper 决定如何解释这个 payload。 - -#### 10.5.3 `genAddDenseIR` 内部如何按 slot 解释第 3 个参数 - -`genAddDenseIR(...)` 的 skeleton 推荐改成: - -```cpp -for each slot i: - runtime = load input_runtimes[i] - if (checkInputNulls && !countStar) { - if (slot.inputCodegen->topLevelIsNull(codegen, runtime, row)) { - goto next_slot; - } - } - addFn(codegen, group, runtime, row, slot, ...) -``` - -这里有两个重要点: - -1. **slot 用哪个 adapter codegen,是编译期常量,不是运行期分派**; -2. 第 3 参始终只是 `i8* runtime`,真正如何解释成 scalar/row runtime,由该 slot 对应的 codegen helper 完成。 - -也就是说,`HashAggrJitOps::AddFn` 仍然可以保持“每个聚合一个 add 函数”的结构,但其参数语义应从: - -```cpp -llvm::Value* decoded -``` - -改成: - -```cpp -llvm::Value* inputRuntime -``` - -然后在 `sum/min/max/avg/...` 的 addFn 内部统一写成: - -```cpp -auto* inputRow = slot.inputCodegen->readIRRow(codegen, inputRuntime, row, slot); -``` - -这样 GroupOps 看到的始终就是 `IRRow`,不再碰 `HashAggrJitDecodedInput` 的字段偏移。 - -#### 10.5.4 运行时 payload 推荐形状 - -推荐 root runtime 只保留两种: - -```cpp -struct ScalarInputRuntime { - const void* values; - const int32_t* indices; - const uint64_t* nulls; -}; - -struct RowInputRuntime { - const uint64_t* nulls; - const ScalarInputRuntime* const* children; -}; -``` - -这里刻意**不再放**: - -- `decodedVector` -- `rowField0Values` -- `rowField1Values` - -因为这些都是把框架重新绑回旧 descriptor / 特定聚合语义的回退路线。 - -avg merge / decimal sum merge 这类历史快路径,应该改为: - -- `RowInputRuntime.children[0]` 指向 field0 的 `ScalarInputRuntime` -- `RowInputRuntime.children[1]` 指向 field1 的 `ScalarInputRuntime` - -这样 JIT 仍然可以直接读 child flat raw values,并不会失去快路径。 - -#### 10.5.5 `genAddDenseIR` 的无 null 快路径怎么保留 - -这部分仍然建议保留今天的双函数模型: - -- `addDense`:会做 top-level null check -- `addDenseNoNull`:不做 top-level null check - -也就是 LLVM ABI 仍是同一个 3 参函数类型,只是生成两份实现。 - -变化点不在函数签名,而在 skeleton 里的 null 判断从: - -```cpp -loadDecodedNulls(decoded) -``` - 
-变成: - -```cpp -slot.inputCodegen->loadTopLevelNulls(runtime) -// or slot.inputCodegen->topLevelIsNull(...) -``` - -这样 scalar / row 输入都能复用同一套外层 skeleton,而不是把 null 逻辑重新散落到各个聚合实现里。 - -#### 10.5.6 对实现顺序的直接指导 - -因此真正落地时,`genAddDenseIR` 这条线建议按下面顺序改: - -1. 先把 `decodedInputs` 变量/注释/语义重命名为 `inputRuntimes`; -2. 把 `HashAggrJitOps::AddFn` 的 `decoded` 参数语义改成 `inputRuntime`; -3. 在 slot 上挂 compile-time 的 input codegen/helper 信息; -4. 把外层 null gating 改成走 adapter helper; -5. 最后再删 `HashAggrJitDecodedInput`、`offsetof(...)` 常量与 `loadDecoded*` 专名 API。 - -**所以答案是:能去掉,而且最合理的 `genAddDenseIR` 设计不是改成“传 adapter 对象”,而是保留 `void(i8**, i32, i8**)` 形状,把第三个参数升级成 adapter-owned runtime payload 数组。** - -#### 10.5.7 `HashAggrJitDecodedInput` 是否应该改成 union - -这个方向**是可行的,而且比“继续扩一个大 struct”更优**;在当前 bolt hash aggr JIT 这条路径里, -我现在进一步收敛为: - -- **顶层输入 runtime 可以直接用无 tag 的 union root** -- scalar / row 由 **codegen 时已知的 adapter 结构** 决定,而不是由 runtime node 自描述 -- 但 **row 的 child 当前进一步收紧为 scalar-only**,不再让 child 也走统一 union node - -推荐形状: - -```cpp -union HashAggrJitInputRuntime; - -struct HashAggrJitScalarInputRuntime { - const void* values; - const int32_t* indices; - const uint64_t* nulls; -}; - -struct HashAggrJitRowInputRuntime { - const uint64_t* nulls; - const HashAggrJitScalarInputRuntime* const* children; - int32_t numChildren; -}; - -union HashAggrJitInputRuntime { - HashAggrJitScalarInputRuntime scalar; - HashAggrJitRowInputRuntime row; -}; -``` - -##### 为什么 union 方向是对的 - -因为它解决了当前 `HashAggrJitDecodedInput` 最大的问题: - -1. **把 scalar / row 两类输入形状显式分开**,而不是塞进一个横向扩字段的大 struct; -2. `rowField0Values / rowField1Values` 这种“为了某个聚合临时开洞”的模式可以消失; -3. 第三个参数仍然可以是“runtime root 指针数组”,不影响 `genAddDenseIR` 的 3 参 ABI 形状; -4. InputAdapter 的职责能真正落到“从 union runtime 读出 `IRRow`”,而不是继续围绕旧 `DecodedInput` 的字段偏移打补丁。 - -##### 为什么这里可以省掉 shape/tag - -因为当前 add_dense 的生成方式决定了: - -1. 每个 slot 在 codegen 时已经知道输入是 scalar 还是 row; -2. 当前 row 的每个 child 固定为 scalar,child 的读取方式在 codegen 时也是已知的; -3. 
热路径不需要 runtime shape dispatch,只需要按已知形状直接 load 对应字段。 - -因此,对 bolt 这条 JIT 路线而言,runtime node 的职责就是“承载值指针/nulls/children(以及 scalar 自己的 indices)”, -而不是“再告诉 JIT 自己是什么类型”。 - -##### 无 tag union 的前提条件 - -无 tag union 成立的前提是: - -1. **不能在热路径做 runtime kind 分派**; -2. slot 必须绑定 compile-time 的 input codegen/helper; -3. row child 的访问路径必须来自已知的 adapter 结构,而不是依赖 runtime 自描述; -4. 若需要 debug/assert,应由构造阶段或非 hot 校验逻辑承担,而不是把 tag 常驻在 runtime node 上。 - -也就是说:**union 是 runtime 承载方式,shape 是 codegen 元信息,不必塞进 runtime node。** - -##### 与 `genAddDenseIR` 接口的关系 - -即便采用无 tag union,`genAddDenseIR` 的 LLVM 接口也**不需要**变成复杂签名;仍建议保持: - -```cpp -void (*)(char** groups, int32_t numRows, char** inputRuntimes) -``` - -或者在 C++ typedef 层写成更强语义版本: - -```cpp -using HashAggrJitAddDenseFunc = - void (*)(char** groups, - int32_t numRows, - HashAggrJitInputRuntime* const* inputRuntimes); -``` - -但 LLVM IR 里依然可以保持 `i8**`,避免 ABI 扩散。 - -##### 什么时候 union 比“分离 root struct + void*”更优 - -我现在更偏向这个简化后的 union 方案,前提是满足下面两点: - -1. 顶层输入 runtime 统一收敛到 scalar/row 两种 root 形状; -2. RowInputRuntime 的 child 固定为 scalar runtime,由 `RowInputAdapterCodegen` 的 scalar child readers 解释成 `IRRow`。 - -因为当前实际需求只覆盖 row-of-scalars,这比“外面全是 `void*`,每层都靠约定 cast”更稳,也比提前支持递归 row 更容易落地。 - -##### 什么时候 union 仍然不够好 - -如果只是把当前这个 struct 生硬改成: - -- 一个 scalar variant -- 一个仍然带 `rowField0/rowField1` 的 row variant - -那仍然不够好,因为这只是把“avg / decimal sum 的聚合语义”从 struct 平铺变成 union variant,**没有真正建立 generic row runtime**。 - -所以采用这一版 union 的最低要求是: - -- row variant 只能有 `nulls/children/numChildren` -- `children` 必须直接指向 `ScalarInputRuntime` -- 不能再出现 `field0/field1` 这种聚合专有字段名 - -##### 结论 - -因此,对“是否可将 `HashAggrJitDecodedInput` 改成 union,并将 union 指针作为 add_dense 第三个参数传入 LLVM function”这个问题,我的结论是: - -- **可以,而且方向是对的;** -- **比继续沿用一个大而全的 struct 更优;** -- **在当前 bolt JIT 路线里,runtime node 可以不带 shape/tag;** -- **并且当前阶段 row variant 应进一步限定为 scalar-children 形状,否则实现复杂度会明显超前于需求。** - -### 10.6 为什么不会导致性能回退 - -性能保护原则: - -1. **不在每行调用虚函数**:virtual dispatch 仅用于 batch 准备; -2. 
**不在每行调用通用 runtime helper**:常见标量输入仍展开成直接 load `values + indices[row]` / `nulls[row]`; -3. **保留 raw child fast path**:row merge 若 child 已是 flat canonical scalar runtime,codegen 直接读取 child runtime,不退回 `DecodedVector` helper; -4. **让 encoding 差异前置到 adapter 构造**:dictionary/constant/flat 的分歧在 adapter `prepare()` 内吸收,JIT IR 只面对 canonical runtime payload; -5. **外层 add-dense skeleton 不被打散**:只替换“单 slot 如何读 input”,不引入额外 per-row 框架判断。 - -换句话说,InputAdapter 的虚接口是**对象建模边界**,不是热循环执行模型。 - -### 10.7 与当前 patch 序列的衔接 - -当前已经落下的 `IRRow`、`DecodedScalarInputAdapter`、`FlatOutputAdapter` 可以视为最终架构的前置垫片: - -- `IRRow`:保留,作为三层之间唯一值契约; -- `DecodedScalarInputAdapter`:后续升格为 `ScalarInputAdapterCodegen`,不再依附旧 `HashAggrJitDecodedInput` 命名; -- `RowOutputAdapter`:必须保持 generic,只按 struct field 写回,不认 avg 的 2-field 语义; -- `HashAggrJitCodegen::loadDecoded*`:逐步收缩为 adapter runtime 读取 helper,最终删掉 decoded-input 专名 API。 - -### 10.8 建议迁移顺序 - -#### Phase A:先把 codegen 边界改对 - -1. 把当前 `RowOutputAdapter` 改成真正 generic 的 struct writer; -2. 在 `HashAggrJit.h/.cpp` 中引入 input runtime union / adapter codegen 概念; -3. 让 avg partial、decimal merge 等 row 输入/输出先走 generic row contract,而不是 field0/field1 语义 helper。 - -#### Phase B:引入 runtime InputAdapter 对象,但暂不改 add_dense ABI - -1. `GroupingSet` 内部改为构造 `ScalarInputAdapter` / `RowInputAdapter`; -2. adapter 自己持有 runtime payload; -3. 传给 JIT 的仍可先保持 `char** inputs`,但每个元素改为指向 adapter-owned runtime,而不是 `HashAggrJitDecodedInput`。 - -这一阶段完成后,`HashAggrJitDecodedInput` 已经可以从执行路径移除,只剩个别 helper / test 兼容点。 - -#### Phase C:删除旧 descriptor 与旧命名 helper - -1. 删除 `HashAggrJitDecodedInput`; -2. 删除 `loadDecodedValue/loadDecodedNulls/loadDecodedRowField*` 这组旧 API; -3. 测试与 benchmark 一律改用 adapter 构造路径; -4. 
清理 `offsetof(HashAggrJitDecodedInput, ...)` 常量与相关 runtime helper。 - -### 10.9 本章对应的 DoD 补充 - -完成 InputAdapter 重构后,应额外满足: - -- [ ] `GroupingSet` 不再直接构造 `HashAggrJitDecodedInput`; -- [ ] JIT add-dense ABI 传递的是 adapter-owned runtime payload; -- [ ] row merge / row extract 不再出现 `rowField0/rowField1` 这类聚合专有字段名; -- [ ] 新增一个 3-field intermediate 聚合时,不需要修改 InputAdapter/OutputAdapter 基类接口。 - ---- - -## 11. 事故复盘:`munmap_chunk(): invalid pointer`(commit `0722a59851` 引入) - -本章记录一次在 `HashAggrJitBenchmark` 上复现的堆破坏崩溃的完整定位过程与根因,作为 output runtime 绑定相关改动的回归警示。 - -### 11.1 现象 - -- 运行 `bolt_hashaggr_jit_benchmark`(RelWithDebInfo 行为,Release preset 构建)必崩。 -- 报错:`munmap_chunk(): invalid pointer`,`SIGABRT`。 -- 栈顶在算子关闭阶段析构中间结果 vector 时: - - `Driver::closeOperators()` → 释放 `RowVector` → 释放其 child `FlatVector`(int64) 的 values buffer → glibc `free` 检测到非法 chunk 指针。 -- hint:bug 出现在最近 5 个 commit 中。 - -### 11.2 定位过程 - -1. **缩小到具体 case**:在 benchmark `addCase()` 的 warmup 处加临时 `fprintf` 打印每个 case 名(每个 case warmup 时会先跑 nojit 再跑 jit)。运行后最后一条输出停在 `width4_merge_decimal_avg` 的 `jit` 阶段,**坐实崩溃 case = `width4_merge_decimal_avg`**。 -2. **bisect 到 commit**:先前已通过 `git reset --hard` 确认 first bad commit = `0722a59851`(其 parent `f752929ecc` 不崩)。 -3. **gdb 观察**: - - 崩溃发生在第二阶段(final aggregation)输出路径 `GroupingSet::runHashAggrJitExtractChunks`。 - - 在 decimal avg 的两个 helper 上下断:`jit_HashAggrExtractPartialDecimalAvg` 被调用 40000 次,但 `jit_HashAggrExtractFinalDecimalAvg` **一次都没进入**就崩了 → 说明堆已在 final 阶段“**extract 绑定阶段**”(`chunk.extract()` 之前)被破坏。 -4. 
**类型/精度推演**: - - `width4` 用 `DECIMAL(12,2)`(short decimal)。 - - decimal avg 中间 sum 类型按签名 `ROW(DECIMAL(38, a_scale), BIGINT)` → `DECIMAL(38,2)` 是 **long decimal(int128)**。 - - decimal avg final 结果类型 `r_precision=min(38,12+4)=16` → `DECIMAL(16,6)` 是 **short decimal**,存储为 **`FlatVector<int64_t>`**。 - - 但 descriptor 的 `accumulatorKind = Int128`(见 `AverageAggregate.cpp` 的 `DecimalAverageAggregate::createHashAggrJitDescriptor`)。 - -### 11.3 根因 - -`0722a59851` 把 `GroupingSet.cpp` 里的 `hashAggrJitRawOutputValues`(改名为 `hashAggrJitRawOutputData`)的 `Int128` 分支,从父 commit 的 `return nullptr` 改成了: - -```cpp -case jit::HashAggrJitValueKind::Int128: - return vector->asUnchecked<FlatVector<int128_t>>()->mutableRawValues(); -``` - -而 `runHashAggrJitExtractChunks` 的 **scalar final 输出绑定**(`GroupingSet.cpp:1306` 附近)用 `slot.desc.accumulatorKind` 来解释输出列: - -```cpp -.values = hashAggrJitRawOutputData(aggregateVector.get(), slot.desc.accumulatorKind) -``` - -对 decimal avg final:`accumulatorKind == Int128`,但 final 输出列真实类型是 short-decimal `FlatVector<int64_t>`。于是: - -1. 一个真实 `FlatVector<int64_t>` 被 `asUnchecked<FlatVector<int128_t>>()` 强转(类型混淆)。 -2. 调用 `mutableRawValues()`(见 `FlatVector.h:244`):此时 `values_` 是按 int64(8B/elem)分配且非 mutable,函数进入重分配分支: - - 按 `int128`(16B/elem)**重新分配 buffer**; - - `memcpy(newValues, rawValues_, byteSize(length))` 即按 2× 字节数从只有 8B/elem 的旧 buffer **越界读**; - - 把该 vector 的 `values_` / `rawValues_` 替换成 int128 尺寸 buffer。 -3.
这步破坏堆(越界读踩坏相邻 chunk metadata,并把列状态搞乱),最终在算子析构释放该 `RowVector`/`FlatVector` 链时 glibc 报 `munmap_chunk(): invalid pointer`。 - -**为何 parent commit 不崩**:原 `Int128` 分支 `return nullptr`,从不触碰该列 buffer。decimal avg final 真正写入走 helper `jit_HashAggrExtractFinalDecimalAvg`,由 `longDecimal` flag 正确按 int64/int128 写回,**根本不需要这个预取的 raw values 指针**。 - -**关键定性**:crash 由 commit `0722a59851` 的这一行引入(`bolt/exec/GroupingSet.cpp` 内 `hashAggrJitRawOutputData` 的 `Int128` 分支),与 scalar-output 绑定处用 `accumulatorKind` 解释 short-decimal 输出列的错配共同作用。它本质是一个 **`accumulatorKind` ≠ 输出 vector 实际存储类型** 的类型混淆。 - -### 11.4 验证 - -把 `Int128` 分支临时改回 `return nullptr`(仅验证用,注释说明 Int128 scalar/decimal 输出走 helper 的 `vector()`,不读此 raw 指针),重编译运行: - -- `width4/8/16/32_merge_decimal_avg` 全部通过,crash 消失; -- 整个 benchmark 跑完无 `munmap` / `Aborted`。 - -→ 根因实锤。 - -### 11.5 修复(已实施:方案 1) - -采用方案 1:**scalar output 绑定按输出 vector 真实类型推导 kind**,而非 `accumulatorKind`。 - -`runHashAggrJitExtractChunks` 的 FLAT scalar 输出绑定改为: - -```cpp -const auto outputKind = hashAggrJitOutputValueKind(aggregateVector.get()); -if (!outputKind.has_value()) { - canRunChunk = false; - skipReason = "unsupported scalar output value kind"; - break; -} -... .values = hashAggrJitRawOutputData(aggregateVector.get(), *outputKind) ... -``` - -`hashAggrJitOutputValueKind` 已存在,会按列真实类型(含 short/long decimal)推导 kind,从而保证 `hashAggrJitRawOutputData` 取到的指针宽度与列存储宽度一致,杜绝 int64↔int128 错配重分配。`hashAggrJitRawOutputData` 的 `Int128` 分支保持正常实现(用于真正的 long-decimal/HUGEINT 输出列)。 - -其余备选方向(方案 2/3)未采用,记录备查: - -2. 对走 helper 的 decimal/Int128 输出不预取 raw values(保持 nullptr); -3. 统一约束指针宽度一致。 - -### 11.6 临时改动清理(已完成) - -- `bolt/exec/benchmarks/HashAggrJitBenchmark.cpp`:`addCase()` 内的 `fprintf` case 名打印 —— **已回退**。 -- `bolt/exec/GroupingSet.cpp`:`hashAggrJitRawOutputData` 的 `Int128` 分支临时 `return nullptr` —— **已恢复**为正常实现;正式修复落在 scalar 绑定处(见 §11.5)。 - ---- - -## 12.
优化:ROW merge 输入跳过 per-field null 检查(`readRowFieldValue`) - -### 12.1 背景 - -`RowInputAdapterCodegen::readRowField` 对每个 ROW child 都会生成一段 per-field null 检查 CFG(`row_field_null_check` / `row_field_null_done` + PHI)来产出该 field 的 `is_null`,再 `IRRow::pack(value, is_null)`。 - -但 `addIntermediateResults`(merge)路径上,框架外层 `genAddDenseIR` 已对 **top-level ROW null** 统一发射过 null guard;ROW 内部各 field 是否需要 null 位,取决于具体聚合语义: - -- **avg merge**(`ROW(double sum, bigint count)`):业务上不读 field 的 null,只用 value; -- **decimal sum merge**(`ROW(decimal sum, bool isEmpty)`):仅 **sum** 字段的 null 被用来编码 overflow(JIT partial extract 溢出时 `sumVector->setNull`),`isEmpty` 字段的 null 不被消费; -- **decimal avg merge**(`ROW(decimal sum, bigint count)`):sum/count 的 null 都参与 overflow 判定,**不能跳过**。 - -因此对“null 位未被业务消费”的 field,生成 null 检查 CFG 是纯浪费。由于 `nulls` 指针编译期未知,这段 CFG 在 decimal 路径上 LLVM 往往**折不掉**,既增加 IR 体积也增加实际指令。 - -### 12.2 改动 - -新增 value-only 接口 `InputAdapterCodegen::readRowFieldValue(row, field, kind)`: - -- 语义:只返回 ROW child 的裸值(`llvm::Value*`),**跳过** per-field null 检查 CFG; -- 适用前提:该 field 在当前路径上保证非空(其 null 位不被聚合语义消费); -- `ScalarInputAdapterCodegen`:`BOLT_UNSUPPORTED`(与 `readRowField` 一致); -- `RowInputAdapterCodegen`:直接 `loadChild(field)` → `loadScalarInputValue(...)`,不调 `isRowFieldNull`。 - -调用点改造(严格按 null 是否被消费区分): - -| 调用点 | 处理 | -|--------|------| -| `compileAvgAddIntermediateResults`(sum/count) | 两字段全换 `readRowFieldValue` | -| `compileDecimalSumAddIntermediateResults` — `isEmpty` | 换 `readRowFieldValue` | -| `compileDecimalSumAddIntermediateResults` — `sum` | **保留** `readRowField`(`sumIsNull` 编码 overflow) | -| `compileDecimalAvgAddIntermediateResults` — sum/count | **保留** `readRowField`(两个 null 都参与 overflow 判定) | - -### 12.3 正确性依据 - -JIT 的 partial extract(`HashAggrDecimalRuntime.cpp` 的 `jit_HashAggrExtractPartialDecimalSum/Avg`)在 sum 溢出时会 `sumVector->setNull(row, true)`(整行 ROW 非 null)。下游 merge 正是靠读该 field 的 null 来识别并传播 overflow,故 decimal 的 sum 字段(及 decimal avg 的 count)**必须**保留 `readRowField`。注意:非 JIT 的 
`extractAccumulators` 不置 sum-null,但 JIT pipeline 中上游可能是 JIT partial extract,跨阶段契约要求 merge 端兼容 sum=null。 - -### 12.4 性能验证(Release,jit 路径,单位 ms,越小越好) - -| case | baseline | optimized | 变化 | -|------|----------|-----------|------| -| width8_merge_avg | 6.16 | 6.00 | -2.6% | -| width16_merge_avg | 11.08 | 10.68 | -3.6% | -| width32_merge_avg | 21.47 | 20.51 | -4.5% | -| width4_merge_decimal_sum | 8.81 | 8.06 | -8.5% | -| width8_merge_decimal_sum | 15.90 | 15.48 | -2.6% | -| width16_merge_decimal_sum | 30.21 | 29.74 | -1.6% | -| width32_merge_decimal_sum | 61.05 | 60.14 | -1.5% | - -- avg merge:宽度越大收益越明显(每多一列多省一段 null CFG,线性放大); -- decimal sum merge:稳定小幅提升; -- `sum`(标量输入)不受影响;功能无回归,benchmark 无 crash。 - -### 12.5 涉及文件 - -- `bolt/jit/aggregation/HashAggrJit.h`:`InputAdapterCodegen` 新增 `readRowFieldValue` 纯虚 + 两子类声明; -- `bolt/jit/aggregation/HashAggrJit.cpp`:两子类实现; -- `bolt/jit/aggregation/ops/AvgOps.cpp`、`ops/DecimalSumOps.cpp`:按上表切换调用。 - ---- - -## 13. Decimal short/long 专用 helper 修复与性能结论(2026-06-13) - -本轮围绕 decimal sum/avg 的 short/long decimal 判断与 runtime helper 做了两类收敛: - -1. **descriptor 语义统一**:`precision/scale` 表示 intermediate/partial decimal 类型,`auxPrecision/auxScale` 表示 final result decimal 类型;decimal sum 因 partial/final 类型相同,所以 `aux*` 镜像 `precision/scale`。 -2. 
**short/long decimal 在 codegen 期固定**:不再把 `longDecimal` 作为外部 C++ runtime helper 参数传入,避免 LLVM 无法跨外部函数边界消除无效分支;`emitDecimalSumExtract` / `emitDecimalAvgExtract` 直接按实际输出精度选择 short/long 专用 helper。 - -当前专用 helper 形态: - -```text -jit_HashAggrExtractFinalShortDecimalSum -jit_HashAggrExtractFinalLongDecimalSum -jit_HashAggrExtractPartialShortDecimalSum -jit_HashAggrExtractPartialLongDecimalSum -jit_HashAggrExtractFinalShortDecimalAvg -jit_HashAggrExtractFinalLongDecimalAvg -jit_HashAggrExtractPartialShortDecimalAvg -jit_HashAggrExtractPartialLongDecimalAvg -``` - -验证命令: - -```bash -cmake --build --preset conan-release --target bolt_hashaggr_jit_benchmark --parallel 16 -./_build/Release/bolt/exec/benchmarks/bolt_hashaggr_jit_benchmark --bm_regex='(width8|width16)' -``` - -与用户提供的 baseline 单次结果相比,关键结论如下: - -| case | baseline | 当前 | 变化 | -|---|---:|---:|---:| -| `width8_merge_decimal_sum_jit` | 21.68ms | 15.24ms | **-29.70%** | -| `width16_merge_decimal_sum_jit` | 42.44ms | 29.47ms | **-30.56%** | -| `width8_merge_decimal_avg_jit` | 13.67ms | 13.73ms | +0.44% | -| `width16_merge_decimal_avg_jit` | 26.90ms | 26.45ms | -1.67% | - -汇总: - -| 分组 | 几何平均变化 | -|---|---:| -| 所有 JIT 项 | **-4.23%** | -| decimal JIT 项 | **-16.67%** | -| 非 decimal JIT 项 | -1.98% | - -因此,当前结论是: - -- `decimal_sum_jit` 从 baseline 的“慢于 nojit”变成“明显快于 nojit”: - - width8:15.24ms vs nojit 19.52ms,约 **21.9%** faster; - - width16:29.47ms vs nojit 39.11ms,约 **24.6%** faster。 -- `decimal_avg_jit` 基本持平,无系统性回退。 -- 非 decimal 项与本轮改动无直接关系,单次结果有正有负,整体未观察到系统性退化。 -- 后续 decimal extract/merge 相关重构应继续坚持:**能在 codegen 期确定的类型选择,不要作为 runtime 参数留给外部 helper 分支处理**。 From 24cc7fe6e2bd49159af873d926cffeb359307c0b Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Sun, 21 Jun 2026 22:33:10 +0800 Subject: [PATCH 92/98] fix tidy From 3f4a0432e78d1b8c0d0741b6b8b333447380e796 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 22 Jun 2026 14:20:12 +0800 Subject: [PATCH 93/98] fix diff --- bolt/jit/aggregation/HashAggrJit.cpp | 25 
+++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index c5ae64505..3f093f6ed 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1082,14 +1082,23 @@ std::string HashAggrJitSlot::getDescription() const { } inputs << "]"; + // NOTE: nullByte/nullMask MUST be part of the description because they are + // baked into the generated IR as compile-time constants (see + // clearAccumulatorNull/setAccumulatorNull). The description drives the JIT + // module cache key (functionName_); omitting them lets two slots that share + // the same aggregate semantics and accumulator offset but have a different + // null-bit layout reuse the same compiled code, which would read-modify-write + // the wrong null bit and corrupt a neighboring grouping key's null flag. return fmt::format( - "{}_raw{}_partial{}({})->{}@{}", + "{}_raw{}_partial{}({})->{}@{}@nb{}@nm{}", hashAggrJitKindName(desc.kind), desc.context.isRawInput, desc.context.isPartialOutput, inputs.str(), desc.context.outputType()->toString(), - offset); + offset, + nullByte, + nullMask); } HashAggrJitChunk::HashAggrJitChunk(std::vector slots) @@ -1099,6 +1108,18 @@ HashAggrJitChunk::HashAggrJitChunk(std::vector slots) "jit_hashaggr_v2_n{}_h{:016x}", slots_.size(), bits::hashBytes(1, description.data(), description.size())); + // TODO(hash_aggr_jit): temporary diagnostics to confirm/rule out JIT module + // cache collisions across grouping sets that share the same description but + // have different null-bit layouts. Prints the cache key together with each + // slot's offset/nullByte/nullMask. Remove once the root cause is confirmed. 
+ for (const auto& slot : slots_) { + LOG(INFO) << "HashAggrJit slot layout: functionName=" << functionName_ + << " description=" << description + << " aggregateIndex=" << slot.aggregateIndex + << " offset=" << slot.offset << " nullByte=" << slot.nullByte + << " nullMask=0x" << std::hex << static_cast(slot.nullMask) + << std::dec; + } } std::string HashAggrJitChunk::getDescription() const { From aa589fc7df6dce3115385356c50ce1eb91359c50 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 22 Jun 2026 15:23:56 +0800 Subject: [PATCH 94/98] fix diff caused by basevector::numNulls_ --- bolt/exec/GroupingSet.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index e8912707c..c387e5fa0 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -231,6 +231,20 @@ bool fillHashAggrJitRowOutputRuntime( return true; } +void resetHashAggrJitOutputDataDependentFlags( + BaseVector* vector, + const jit::HashAggrJitSlot& slot) { + vector->resetDataDependentFlags(nullptr); + if (slot.desc.outputShape() != jit::HashAggrJitRuntimeShape::Row) { + return; + } + + auto* rowVector = vector->asUnchecked(); + for (auto i = 0; i < rowVector->childrenSize(); ++i) { + rowVector->childAt(i)->resetDataDependentFlags(nullptr); + } +} + #endif std::optional makeHashAggrJitSlot( @@ -1347,6 +1361,9 @@ void GroupingSet::runHashAggrJitExtractChunks( chunk.extract(groups.data(), groups.size(), resultPtrs.data()); } for (const auto& slot : chunk.slots()) { + resetHashAggrJitOutputDataDependentFlags( + result->childAt(slot.aggregateIndex + aggregateOutputOffset).get(), + slot); jitExtracted[slot.aggregateIndex] = 1; } } From 62c8f5275636289e1e5500542044f6e2bd374474 Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 22 Jun 2026 15:44:49 +0800 Subject: [PATCH 95/98] remove useless codes --- bolt/jit/aggregation/HashAggrJit.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git 
a/bolt/jit/aggregation/HashAggrJit.cpp b/bolt/jit/aggregation/HashAggrJit.cpp index 3f093f6ed..eb1387ffe 100644 --- a/bolt/jit/aggregation/HashAggrJit.cpp +++ b/bolt/jit/aggregation/HashAggrJit.cpp @@ -1108,18 +1108,6 @@ HashAggrJitChunk::HashAggrJitChunk(std::vector slots) "jit_hashaggr_v2_n{}_h{:016x}", slots_.size(), bits::hashBytes(1, description.data(), description.size())); - // TODO(hash_aggr_jit): temporary diagnostics to confirm/rule out JIT module - // cache collisions across grouping sets that share the same description but - // have different null-bit layouts. Prints the cache key together with each - // slot's offset/nullByte/nullMask. Remove once the root cause is confirmed. - for (const auto& slot : slots_) { - LOG(INFO) << "HashAggrJit slot layout: functionName=" << functionName_ - << " description=" << description - << " aggregateIndex=" << slot.aggregateIndex - << " offset=" << slot.offset << " nullByte=" << slot.nullByte - << " nullMask=0x" << std::hex << static_cast(slot.nullMask) - << std::dec; - } } std::string HashAggrJitChunk::getDescription() const { From 225055b542e015f2351994ddbac4b4ef810145bd Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 22 Jun 2026 16:10:21 +0800 Subject: [PATCH 96/98] fix aggJitCodegenTimeNs zero issue --- bolt/exec/GroupingSet.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bolt/exec/GroupingSet.h b/bolt/exec/GroupingSet.h index e6c0f8a83..e49837242 100644 --- a/bolt/exec/GroupingSet.h +++ b/bolt/exec/GroupingSet.h @@ -216,6 +216,9 @@ class GroupingSet { } common::AggregationStats getRuntimeStats() { +#ifdef ENABLE_BOLT_JIT + waitForHashAggrJitCompilation(); +#endif return stats_; } From 4e5e493934b715cda6fadbc529102069f6e51c5f Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 22 Jun 2026 20:27:38 +0800 Subject: [PATCH 97/98] remove useless logs --- bolt/exec/GroupingSet.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp 
b/bolt/exec/GroupingSet.cpp index c387e5fa0..377340df8 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1125,8 +1125,12 @@ void GroupingSet::runHashAggrJitAddChunks( const RowVectorPtr& input, bool mayPushdown, std::vector& jitExecuted) { - if (hashAggrJitChunks_.empty() || hasSpilled() || bypassProbeHT_ || - supportRowBasedOutput_ || !activeRows_.isAllSelected()) { + if (hashAggrJitChunks_.empty()) { + return; + } + + if (hasSpilled() || bypassProbeHT_ || supportRowBasedOutput_ || + !activeRows_.isAllSelected()) { LOG(INFO) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() << " hasSpilled=" << hasSpilled() << " bypassProbeHT=" << bypassProbeHT_ @@ -1262,8 +1266,11 @@ void GroupingSet::runHashAggrJitExtractChunks( const RowVectorPtr& result, int32_t aggregateOutputOffset, std::vector& jitExtracted) { - if (hashAggrJitChunks_.empty() || groups.empty() || hasSpilled() || - supportRowBasedOutput_) { + if (hashAggrJitChunks_.empty()) { + return; + } + + if (groups.empty() || hasSpilled() || supportRowBasedOutput_) { LOG(INFO) << "HashAggrJit extract skipped: chunks=" << hashAggrJitChunks_.size() << " groups=" << groups.size() << " hasSpilled=" << hasSpilled() << " supportRowBasedOutput=" << supportRowBasedOutput_; From 23390a0aa64b494e2a17c2a13b86a68bfbcb4e9a Mon Sep 17 00:00:00 2001 From: "liyang.127" Date: Mon, 22 Jun 2026 20:59:56 +0800 Subject: [PATCH 98/98] remove useless logs --- bolt/exec/GroupingSet.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/bolt/exec/GroupingSet.cpp b/bolt/exec/GroupingSet.cpp index 377340df8..0a10781e1 100644 --- a/bolt/exec/GroupingSet.cpp +++ b/bolt/exec/GroupingSet.cpp @@ -1131,11 +1131,12 @@ void GroupingSet::runHashAggrJitAddChunks( if (hasSpilled() || bypassProbeHT_ || supportRowBasedOutput_ || !activeRows_.isAllSelected()) { - LOG(INFO) << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() - << " hasSpilled=" << hasSpilled() - << " 
bypassProbeHT=" << bypassProbeHT_ - << " supportRowBasedOutput=" << supportRowBasedOutput_ - << " activeRowsAllSelected=" << activeRows_.isAllSelected(); + LOG_FIRST_N(INFO, 10) + << "HashAggrJit add skipped: chunks=" << hashAggrJitChunks_.size() + << " hasSpilled=" << hasSpilled() + << " bypassProbeHT=" << bypassProbeHT_ + << " supportRowBasedOutput=" << supportRowBasedOutput_ + << " activeRowsAllSelected=" << activeRows_.isAllSelected(); return; }