[VPlan] Use control flow to implement MaskedCond and preserve SSA#201784
[VPlan] Use control flow to implement MaskedCond and preserve SSA#201784lukel97 wants to merge 20 commits into
Conversation
|
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-analysis Author: Luke Lau (lukel97) ChangesStacked on #201783 If we have an early exit loop with non-dereferenceable loads after the exit, we currently bail: int z;
for (int i = 0; i < N; i++) {
if (x[i])
break;
z = y[i];
}If the early exit block dominates the block containing these loads, we could predicate these loads. This PR helps prepare for this by modelling the act of taking an early exit in the VPlan CFG, so that blocks can be automatically predicated by VPlanPredicator. The idea is similar to how tail folding was modelled explicitly in the VPlan CFG in #176143, except we branch to a minimal latch from the exiting block on an exit mask flowchart TD
header["header"]
exiting["exiting"]
predicated["predicated"]
latch["latch"]
header --> exiting
exiting -->| exitmask | predicated
predicated --> latch
exiting -->| !exitmask | latch
This approach works fine for non-predicated early exits, but for early exits with predication VPInstruction::MaskedCond ends up producing redundant masks stemming from the extra predication at each exiting block. The solution to this is to preserve SSA form by inserting phi nodes where needed. With the improvements to VPPredicator the exit mask folds away entirely. Maintaining SSA form also has the benefit that we can remove a lot of the side-steps for VPInstruction::MaskedCond in VPlanVerifier.cpp. The exit mask never gets materialized in this PR since legality checks prevent us from needing to attach the mask to any instruction. This can be enabled afterwards in another PR. Patch is 70.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/201784.diff 18 Files Affected:
diff --git a/llvm/include/llvm/Analysis/DominanceFrontier.h b/llvm/include/llvm/Analysis/DominanceFrontier.h
index fd38891e901e3..4a8ab96cf71a7 100644
--- a/llvm/include/llvm/Analysis/DominanceFrontier.h
+++ b/llvm/include/llvm/Analysis/DominanceFrontier.h
@@ -78,6 +78,7 @@ class DominanceFrontierBase {
const_iterator end() const { return Frontiers.end(); }
iterator find(BlockT *B) { return Frontiers.find(B); }
const_iterator find(BlockT *B) const { return Frontiers.find(B); }
+ const_iterator find(const BlockT *B) const { return Frontiers.find(B); }
/// print - Convert to human readable form
///
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b740665fe70bf..8d6be4aa3e465 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1380,9 +1380,6 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
/// Returns true if the VPInstruction does not need masking.
bool alwaysUnmasked() const {
- if (Opcode == VPInstruction::MaskedCond)
- return false;
-
// For now only VPInstructions with underlying values use masks.
// TODO: provide masks to VPInstructions w/o underlying values.
if (!getUnderlyingValue())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h b/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
index 2864670f44913..1ad522880c709 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
@@ -18,6 +18,8 @@
#include "VPlan.h"
#include "VPlanCFG.h"
#include "llvm/ADT/GraphTraits.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/DominanceFrontierImpl.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Support/GenericDomTree.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
@@ -67,5 +69,11 @@ template <>
struct GraphTraits<const VPDomTreeNode *>
: public DomTreeGraphTraitsBase<const VPDomTreeNode,
VPDomTreeNode::const_iterator> {};
+
+class VPPostDominanceFrontier
+ : public DominanceFrontierBase<VPBlockBase, true> {
+public:
+ explicit VPPostDominanceFrontier(const DomTreeT &VPDT) { analyze(VPDT); }
+};
} // namespace llvm
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANDOMINATORTREE_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp b/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
index 2717b80e2eeaa..2ec3df8ccf8c1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
@@ -34,6 +34,9 @@ class VPPredicator {
/// Post-dominator tree for the VPlan.
VPPostDominatorTree VPPDT;
+ /// Post-dominator frontier for the VPlan.
+ VPPostDominanceFrontier VPPDF;
+
/// When we if-convert we need to create edge masks. We have to cache values
/// so that we don't end up with exponential recursion/IR.
using EdgeMaskCacheTy =
@@ -69,8 +72,19 @@ class VPPredicator {
return EdgeMaskCache[{Src, Dst}] = Mask;
}
+ using EdgeTy = std::pair<const VPBasicBlock *, const VPBasicBlock *>;
+
+ /// Compute the "furthest up" set of edges for each incoming value of \Phi.
+ MapVector<EdgeTy, VPValue *> computeBlendEdges(VPPhi *Phi);
+
+ /// Given a set of \p Edges that lead to \p VPBB, return the OR of all edges
+ /// or an equivalent block in-mask.
+ VPValue *createMaskDisjunction(ArrayRef<EdgeTy> Edges, VPBasicBlock *VPBB);
+
+ DenseMap<const VPBasicBlock *, VPBasicBlock::iterator> InsertPoints;
+
public:
- VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan) {}
+ VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan), VPPDF(VPPDT) {}
/// Returns the *entry* mask for \p VPBB.
VPValue *getBlockInMask(const VPBasicBlock *VPBB) const {
@@ -136,6 +150,10 @@ void VPPredicator::createBlockInMask(VPBasicBlock *VPBB) {
// Start inserting after the block's phis, which be replaced by blends later.
Builder.setInsertPoint(VPBB, VPBB->getFirstNonPhi());
+ // Keep track of where in VPBB we are inserting the masks into.
+ scope_exit UpdateInsertPoint(
+ [this, &VPBB]() { InsertPoints[VPBB] = Builder.getInsertPoint(); });
+
// Reuse the mask of the immediate dominator if the VPBB post-dominates the
// immediate dominator.
auto *IDom = VPDT.getNode(VPBB)->getIDom();
@@ -224,7 +242,117 @@ void VPPredicator::createSwitchEdgeMasks(const VPInstruction *SI) {
setEdgeMask(Src, DefaultDst, DefaultMask);
}
+// Compute the "furthest up" set of edges for each incoming value of a phi.
+//
+// Start by keeping track of what edges lead to which value. Then see if any
+// node has the same value for all outgoing edges. If so then propagate that
+// value up to every node it postdominates.
+MapVector<VPPredicator::EdgeTy, VPValue *>
+VPPredicator::computeBlendEdges(VPPhi *Phi) {
+ MapVector<EdgeTy, VPValue *> Edges;
+
+ // Mark the given edge as providing the value \p V.
+ auto AddEdge = [&Edges](const VPBlockBase *From, const VPBlockBase *To,
+ VPValue *V) {
+ EdgeTy Edge = {cast<VPBasicBlock>(From), cast<VPBasicBlock>(To)};
+ assert((!Edges.contains(Edge) || Edges.lookup(Edge) == V) &&
+ "Clobbering an edge?");
+ Edges[Edge] = V;
+ };
+
+ for (auto [InVal, InVPBB] : Phi->incoming_values_and_blocks())
+ AddEdge(InVPBB, Phi->getParent(), InVal);
+
+ // The root phi must postdominate every incoming block. Also don't touch
+ // phis in a reduction chain since they need to be in a specific structure
+ // for handle*Reductions.
+ for (auto [InVal, InVPBB] : Phi->incoming_values_and_blocks())
+ if (!VPPDT.dominates(Phi->getParent(), InVPBB) ||
+ isa<VPReductionPHIRecipe>(InVal))
+ return Edges;
+
+ // Given a list of edges, check if they all have the same value and return it.
+ auto GetAllEqual = [&Edges](ArrayRef<EdgeTy> OutEdges) -> VPValue * {
+ VPValue *Common = nullptr;
+ for (EdgeTy E : OutEdges) {
+ VPValue *V = Edges.lookup(E);
+ if (!V)
+ return nullptr;
+ if (match(V, m_Poison()))
+ continue;
+ if (!Common)
+ Common = V;
+ else if (Common != V)
+ return nullptr;
+ }
+ return Common;
+ };
+
+ SetVector<const VPBlockBase *> Worklist(from_range, Phi->incoming_blocks());
+ while (!Worklist.empty()) {
+ auto *VPBB = cast<VPBasicBlock>(Worklist.pop_back_val());
+
+ // Check that all outgoing edges from VPBB have the same value.
+ SmallVector<EdgeTy> OutEdges;
+ for (const VPBlockBase *Succ : VPBB->getSuccessors())
+ OutEdges.emplace_back(VPBB, cast<VPBasicBlock>(Succ));
+ VPValue *Common = GetAllEqual(OutEdges);
+ if (!Common)
+ continue;
+
+ // They have the same value: we can move the edges up
+ for (EdgeTy Edge : OutEdges)
+ Edges.erase(Edge);
+
+ // Peek through phis that are postdominated by VPBB
+ if (auto *Phi = dyn_cast<VPPhi>(Common))
+ if (VPPDT.dominates(VPBB, Phi->getParent())) {
+ for (auto [InV, InVPBB] : Phi->incoming_values_and_blocks()) {
+ AddEdge(InVPBB, Phi->getParent(), InV);
+ Worklist.insert(InVPBB);
+ }
+ continue;
+ }
+
+ // Iterate up through the post dominance frontier
+ for (const VPBlockBase *Frontier : VPPDF.find(VPBB)->second) {
+ for (const VPBlockBase *FrontierSucc : Frontier->getSuccessors())
+ if (VPPDT.dominates(VPBB, FrontierSucc))
+ AddEdge(Frontier, FrontierSucc, Common);
+ Worklist.insert(cast<VPBasicBlock>(Frontier));
+ }
+ }
+
+ return Edges;
+}
+
+VPValue *VPPredicator::createMaskDisjunction(ArrayRef<EdgeTy> Edges,
+ VPBasicBlock *VPBB) {
+ auto Dsts = map_range(Edges, [](auto E) { return E.second; });
+ const VPBasicBlock *PostDom = *Dsts.begin();
+ for (const VPBasicBlock *VPBB : drop_begin(Dsts))
+ PostDom =
+ cast<VPBasicBlock>(VPPDT.findNearestCommonDominator(PostDom, VPBB));
+ assert(VPPDT.dominates(VPBB, PostDom) && "Edges don't postdominate VPBB");
+ if (PostDom != VPBB)
+ return getBlockInMask(PostDom);
+
+ VPValue *Mask = nullptr;
+ for (auto [Src, Dst] : Edges) {
+ VPValue *EdgeMask;
+ {
+ VPBuilder::InsertPointGuard Guard(Builder);
+ Builder.setInsertPoint(const_cast<VPBasicBlock *>(Dst),
+ InsertPoints[Dst]);
+ EdgeMask = createEdgeMask(Src, Dst);
+ }
+ Mask = Mask ? Builder.createOr(Mask, EdgeMask) : EdgeMask;
+ }
+ return Mask;
+}
+
void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
+ Builder.setInsertPoint(VPBB, InsertPoints[VPBB]);
SmallVector<VPPhi *> Phis;
for (VPRecipeBase &R : VPBB->phis())
Phis.push_back(cast<VPPhi>(&R));
@@ -245,10 +373,30 @@ void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
continue;
}
+ MapVector<VPValue *, SmallVector<EdgeTy>> InValEdgesMap;
+ for (auto [Edge, Val] : computeBlendEdges(PhiR))
+ InValEdgesMap[Val].push_back(Edge);
+ auto InValEdges = InValEdgesMap.takeVector();
+
+ if (InValEdges.size() == 1) {
+ PhiR->replaceAllUsesWith(InValEdges[0].first);
+ PhiR->eraseFromParent();
+ continue;
+ }
+
+ // Sort the incoming value order to match PhiR as much as possible.
+ llvm::stable_sort(InValEdges, [&PhiR](auto &L, auto &R) {
+ auto InVs = PhiR->incoming_values();
+ return std::distance(InVs.begin(), find(InVs, L.first)) <
+ std::distance(InVs.begin(), find(InVs, R.first));
+ });
+
SmallVector<VPValue *, 2> OperandsWithMask;
- for (const auto &[InVPV, InVPBB] : PhiR->incoming_values_and_blocks()) {
+ for (const auto &[InVPV, Edges] : InValEdges) {
+ if (match(InVPV, m_Poison()))
+ continue;
OperandsWithMask.push_back(InVPV);
- OperandsWithMask.push_back(createEdgeMask(InVPBB, VPBB));
+ OperandsWithMask.push_back(createMaskDisjunction(Edges, VPBB));
}
PHINode *IRPhi = cast_or_null<PHINode>(PhiR->getUnderlyingValue());
auto *Blend =
@@ -276,10 +424,8 @@ void VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan) {
// Introduce the mask for VPBB, which may introduce needed edge masks, and
// convert all phi recipes of VPBB to blend recipes unless VPBB is the
// header.
- if (VPBB != Header) {
+ if (VPBB != Header)
Predicator.createBlockInMask(VPBB);
- Predicator.convertPhisToBlends(VPBB);
- }
VPValue *BlockMask = Predicator.getBlockInMask(VPBB);
if (!BlockMask)
@@ -292,6 +438,10 @@ void VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan) {
}
}
+ for (VPBlockBase *VPB : reverse(RPOT))
+ if (VPB != Header)
+ Predicator.convertPhisToBlends(cast<VPBasicBlock>(VPB));
+
// Linearize the blocks of the loop into one serial chain.
VPBlockBase *PrevVPBB = nullptr;
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c3867024c34dc..6c5b4c9f87b03 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -475,7 +475,6 @@ Type *llvm::computeScalarTypeForInstruction(unsigned Opcode,
return IntegerType::get(Ctx, 1);
case VPInstruction::LogicalAnd:
case VPInstruction::LogicalOr:
- case VPInstruction::MaskedCond:
assert((!Op0Ty || Op0Ty->isIntegerTy(1)) && "expected bool operand");
AssertOperandType(1, Op0Ty);
return IntegerType::get(Ctx, 1);
@@ -583,7 +582,6 @@ unsigned VPInstruction::getNumOperandsForOpcode() const {
case VPInstruction::ExtractLastLane:
case VPInstruction::ExtractLastPart:
case VPInstruction::ExtractPenultimateElement:
- case VPInstruction::MaskedCond:
case VPInstruction::Not:
case VPInstruction::ResumeForEpilogue:
case VPInstruction::Reverse:
@@ -1533,7 +1531,6 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
case VPInstruction::FirstOrderRecurrenceSplice:
case VPInstruction::LogicalAnd:
case VPInstruction::LogicalOr:
- case VPInstruction::MaskedCond:
case VPInstruction::Not:
case VPInstruction::PtrAdd:
case VPInstruction::WideIVStep:
@@ -1682,9 +1679,6 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
case VPInstruction::ExitingIVValue:
O << "exiting-iv-value";
break;
- case VPInstruction::MaskedCond:
- O << "masked-cond";
- break;
case VPInstruction::ExtractLane:
O << "extract-lane";
break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5df097628ba7f..dca77b3a3ea5d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1669,11 +1669,6 @@ static void simplifyRecipe(VPSingleDefRecipe *Def) {
return;
}
- // Simplify MaskedCond with no block mask to its single operand.
- if (match(Def, m_VPInstruction<VPInstruction::MaskedCond>()) &&
- !cast<VPInstruction>(Def)->isMasked())
- return Def->replaceAllUsesWith(Def->getOperand(0));
-
// Look through ExtractLastLane.
if (match(Def, m_ExtractLastLane(m_VPValue(A)))) {
if (match(A, m_BuildVector())) {
@@ -2037,6 +2032,15 @@ static void simplifyBlends(VPlan &Plan) {
}
}
+ if (UniqueValues.size() == 2) {
+ for (unsigned I = 0; I != Blend->getNumIncomingValues(); ++I) {
+ if (match(Blend->getIncomingValue(I), m_False())) {
+ StartIndex = I;
+ break;
+ }
+ }
+ }
+
SmallVector<VPValue *, 4> OperandsWithMask;
OperandsWithMask.push_back(Blend->getIncomingValue(StartIndex));
@@ -4126,17 +4130,6 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
continue;
}
- // Lower MaskedCond with block mask to LogicalAnd.
- if (match(&R, m_VPInstruction<VPInstruction::MaskedCond>())) {
- auto *VPI = cast<VPInstruction>(&R);
- assert(VPI->isMasked() &&
- "Unmasked MaskedCond should be simplified earlier");
- VPI->replaceAllUsesWith(Builder.createNaryOp(
- VPInstruction::LogicalAnd, {VPI->getMask(), VPI->getOperand(0)}));
- VPI->eraseFromParent();
- continue;
- }
-
// Lower CanonicalIVIncrementForPart to plain Add.
if (match(
&R,
@@ -4206,6 +4199,7 @@ struct EarlyExitInfo {
VPBasicBlock *EarlyExitingVPBB;
VPIRBasicBlock *EarlyExitVPBB;
VPValue *CondToExit;
+ VPValue *ExitMask;
};
/// Update \p Plan to mask memory operations in the loop based on whether the
@@ -4364,10 +4358,64 @@ static bool handleUncountableExitsWithSideEffects(
return true;
}
+static VPValue *repairSSA(VPValue *Src, VPBasicBlock *SrcVPBB, VPValue *Other,
+ VPBasicBlock *VPBB, VPDominatorTree &VPDT,
+ DenseMap<VPBlockBase *, VPPhi *> &Phis) {
+
+ if (VPDT.dominates(SrcVPBB, VPBB))
+ return Src;
+ if (VPDT.dominates(VPBB, SrcVPBB))
+ return Other;
+ if (VPPhi *Phi = Phis.lookup(VPBB))
+ return Phi;
+
+ SmallVector<VPValue *> InVals;
+ for (auto *Pred : VPBB->predecessors())
+ InVals.push_back(
+ repairSSA(Src, SrcVPBB, Other, cast<VPBasicBlock>(Pred), VPDT, Phis));
+ if (all_equal(InVals))
+ return InVals[0];
+
+ VPPhi *Phi = VPBuilder(VPBB, VPBB->getFirstNonPhi()).createScalarPhi(InVals);
+ Phis[VPBB] = Phi;
+ return Phi;
+}
+
+/// Insert phi nodes to maintain SSA starting from \p VPBB, such that the
+/// resulting value is \p \Src on all paths that go through \p SrcVPBB, and \p
+/// Other otherwise.
+static VPValue *repairSSA(VPValue *Src, VPBasicBlock *SrcVPBB, VPValue *Other,
+ VPBasicBlock *VPBB, VPDominatorTree &VPDT) {
+ DenseMap<VPBlockBase *, VPPhi *> Phis;
+ return repairSSA(Src, SrcVPBB, Other, VPBB, VPDT, Phis);
+}
+
+/// Splits \p LatchVPBB so it only contains the IV increment recipes.
+static VPBasicBlock *splitLatchAtIVInc(VPBasicBlock *LatchVPBB) {
+ auto It = LatchVPBB->getTerminator()->getIterator();
+ while (!match(
+ It->getVPSingleValue(),
+ m_CombineOr(m_Add(m_Isa<VPWidenIntOrFpInductionRecipe>(), m_LiveIn()),
+ m_Sub(m_Isa<VPWidenIntOrFpInductionRecipe>(), m_LiveIn()),
+ m_VPInstruction<Instruction::GetElementPtr>(
+ m_Isa<VPWidenPointerInductionRecipe>(), m_LiveIn())))) {
+ if (It == LatchVPBB->begin())
+ return LatchVPBB;
+ It = std::prev(It);
+ }
+ LatchVPBB = LatchVPBB->splitAt(It);
+ LatchVPBB->setName("vector.latch");
+ return LatchVPBB;
+}
+
bool VPlanTransforms::handleUncountableEarlyExits(
VPlan &Plan, VPBasicBlock *HeaderVPBB, VPBasicBlock *LatchVPBB,
VPBasicBlock *MiddleVPBB, Loop *TheLoop, PredicatedScalarEvolution &PSE,
DominatorTree &DT, AssumptionCache *AC, UncountableExitStyle Style) {
+ // Split the latch at the IV increment so we can branch to it and predicate
+ // any recipes before the increment.
+ LatchVPBB = splitLatchAtIVInc(LatchVPBB);
+
VPDominatorTree VPDT(Plan);
VPBuilder LatchBuilder(LatchVPBB->getTerminator());
SmallVector<EarlyExitInfo> Exits;
@@ -4384,25 +4432,30 @@ bool VPlanTransforms::handleUncountableEarlyExits(
m_BranchOnCond(m_VPValue(CondOfEarlyExitingVPBB)));
assert(Matched && "Terminator must be BranchOnCond");
- // Insert the MaskedCond in the EarlyExitingVPBB so the predicator adds
- // the correct block mask.
VPBuilder EarlyExitingBuilder(EarlyExitingVPBB->getTerminator());
- auto *CondToEarlyExit = EarlyExitingBuilder.createNaryOp(
- VPInstruction::MaskedCond,
+ auto *CondToEarlyExit =
TrueSucc == ExitBlock
? CondOfEarlyExitingVPBB
- : EarlyExitingBuilder.createNot(CondOfEarlyExitingVPBB));
+ : EarlyExitingBuilder.createNot(CondOfEarlyExitingVPBB);
+
+ // Create the exit mask to predicate successors in other lanes.
+ // TODO: This mask doesn't get materialized yet because it's always folded
+ // away. Eventually we need to freeze it to account for the extra use.
+ VPValue *FirstExitLane =
+ EarlyExitingBuilder.createFirstActiveLane(CondToEarlyExit);
+ VPValue *ExitMask = EarlyExitingBuilder.createICmp(
+ CmpInst::ICMP_ULT,
+ EarlyExitingBuilder.createNaryOp(VPInstruction::StepVector, {},
+ FirstExitLane->getScalarType()),
+ FirstExitLane);
+
assert((isa<VPIRValue>(CondOfEarlyExitingVPBB) ||
!VPDT.properlyDominates(EarlyExitingVPBB, LatchVPBB) ||
VPDT.properlyDominates(
CondOfEarlyExitingVPBB->getDefiningRecipe()->getParent(),
LatchVPBB)) &&
"exit condition must dominate the latch");
- Exits.push_back({
- EarlyExitingVPBB,
- ExitBlock,
- CondToEarlyExit,
- });
+ Exits.push_back({EarlyExitingVPBB, ExitBlock, CondToEarlyExit, ExitMask});
}
}
@@ -4516,7 +4569,7 @@ bool VPlanTransforms::handleUncountableEarlyExits(
//
for (auto [Exit, VectorEarlyExitVPBB] :
zip_equal(Exits, VectorEarlyExitVPBBs)) {
- auto &[EarlyExitingVPBB, EarlyExitVPBB, _] = Exit;
+ auto &[EarlyExitingVPBB, EarlyExitVPBB, _, ExitMask] = Exit;
// Adjust the phi nodes in EarlyExitVPBB.
// 1. remove incoming values from EarlyExitingVPBB,
// 2. extract the incoming value at FirstActiveLane
@@ -4540,8 +4593,9 @@ bool VPlanTransforms::handleUncountableEarlyExits(
ExitIRI->addOperand(NewIncoming);
}
- EarlyExitingVPBB->getTerminator()->eraseFromParent();
+ EarlyExitingVPBB->getTerminator()->setOperand(0, ExitMask);
VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
+ VPBlockUtils::connectBlocks(EarlyExitingVPBB, LatchVPBB);
VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
}
@@ -4589,6 +4643,46 @@ bool VPlanTransforms::handleUncountableEarlyExits(
DispatchBuilder.setInsertPoint(CurrentBB);
}
+ // Repair any uses of CondToExit to preserve SSA.
+ VPDT.recalculate(Plan);
+ for (auto [I, Exit] : enumerate(Exits)) {
+ VPValue *Repaired = repairSSA(Exit.CondToExit, Exit.EarlyExitingVPBB,
+ Plan.getFalse(), LatchVPBB,...
[truncated]
|
|
@llvm/pr-subscribers-llvm-transforms Author: Luke Lau (lukel97) ChangesStacked on #201783 If we have an early exit loop with non-dereferenceable loads after the exit, we currently bail: int z;
for (int i = 0; i < N; i++) {
if (x[i])
break;
z = y[i];
}If the early exit block dominates the block containing these loads, we could predicate these loads. This PR helps prepare for this by modelling the act of taking an early exit in the VPlan CFG, so that blocks can be automatically predicated by VPlanPredicator. The idea is similar to how tail folding was modelled explicitly in the VPlan CFG in #176143, except we branch to a minimal latch from the exiting block on an exit mask flowchart TD
header["header"]
exiting["exiting"]
predicated["predicated"]
latch["latch"]
header --> exiting
exiting -->| exitmask | predicated
predicated --> latch
exiting -->| !exitmask | latch
This approach works fine for non-predicated early exits, but for early exits with predication VPInstruction::MaskedCond ends up producing redundant masks stemming from the extra predication at each exiting block. The solution to this is to preserve SSA form by inserting phi nodes where needed. With the improvements to VPPredicator the exit mask folds away entirely. Maintaining SSA form also has the benefit that we can remove a lot of the side-steps for VPInstruction::MaskedCond in VPlanVerifier.cpp. The exit mask never gets materialized in this PR since legality checks prevent us from needing to attach the mask to any instruction. This can be enabled afterwards in another PR. Patch is 70.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/201784.diff 18 Files Affected:
diff --git a/llvm/include/llvm/Analysis/DominanceFrontier.h b/llvm/include/llvm/Analysis/DominanceFrontier.h
index fd38891e901e3..4a8ab96cf71a7 100644
--- a/llvm/include/llvm/Analysis/DominanceFrontier.h
+++ b/llvm/include/llvm/Analysis/DominanceFrontier.h
@@ -78,6 +78,7 @@ class DominanceFrontierBase {
const_iterator end() const { return Frontiers.end(); }
iterator find(BlockT *B) { return Frontiers.find(B); }
const_iterator find(BlockT *B) const { return Frontiers.find(B); }
+ const_iterator find(const BlockT *B) const { return Frontiers.find(B); }
/// print - Convert to human readable form
///
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b740665fe70bf..8d6be4aa3e465 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1380,9 +1380,6 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
/// Returns true if the VPInstruction does not need masking.
bool alwaysUnmasked() const {
- if (Opcode == VPInstruction::MaskedCond)
- return false;
-
// For now only VPInstructions with underlying values use masks.
// TODO: provide masks to VPInstructions w/o underlying values.
if (!getUnderlyingValue())
diff --git a/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h b/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
index 2864670f44913..1ad522880c709 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
@@ -18,6 +18,8 @@
#include "VPlan.h"
#include "VPlanCFG.h"
#include "llvm/ADT/GraphTraits.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/DominanceFrontierImpl.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Support/GenericDomTree.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
@@ -67,5 +69,11 @@ template <>
struct GraphTraits<const VPDomTreeNode *>
: public DomTreeGraphTraitsBase<const VPDomTreeNode,
VPDomTreeNode::const_iterator> {};
+
+class VPPostDominanceFrontier
+ : public DominanceFrontierBase<VPBlockBase, true> {
+public:
+ explicit VPPostDominanceFrontier(const DomTreeT &VPDT) { analyze(VPDT); }
+};
} // namespace llvm
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANDOMINATORTREE_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp b/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
index 2717b80e2eeaa..2ec3df8ccf8c1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
@@ -34,6 +34,9 @@ class VPPredicator {
/// Post-dominator tree for the VPlan.
VPPostDominatorTree VPPDT;
+ /// Post-dominator frontier for the VPlan.
+ VPPostDominanceFrontier VPPDF;
+
/// When we if-convert we need to create edge masks. We have to cache values
/// so that we don't end up with exponential recursion/IR.
using EdgeMaskCacheTy =
@@ -69,8 +72,19 @@ class VPPredicator {
return EdgeMaskCache[{Src, Dst}] = Mask;
}
+ using EdgeTy = std::pair<const VPBasicBlock *, const VPBasicBlock *>;
+
+ /// Compute the "furthest up" set of edges for each incoming value of \Phi.
+ MapVector<EdgeTy, VPValue *> computeBlendEdges(VPPhi *Phi);
+
+ /// Given a set of \p Edges that lead to \p VPBB, return the OR of all edges
+ /// or an equivalent block in-mask.
+ VPValue *createMaskDisjunction(ArrayRef<EdgeTy> Edges, VPBasicBlock *VPBB);
+
+ DenseMap<const VPBasicBlock *, VPBasicBlock::iterator> InsertPoints;
+
public:
- VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan) {}
+ VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan), VPPDF(VPPDT) {}
/// Returns the *entry* mask for \p VPBB.
VPValue *getBlockInMask(const VPBasicBlock *VPBB) const {
@@ -136,6 +150,10 @@ void VPPredicator::createBlockInMask(VPBasicBlock *VPBB) {
// Start inserting after the block's phis, which be replaced by blends later.
Builder.setInsertPoint(VPBB, VPBB->getFirstNonPhi());
+ // Keep track of where in VPBB we are inserting the masks into.
+ scope_exit UpdateInsertPoint(
+ [this, &VPBB]() { InsertPoints[VPBB] = Builder.getInsertPoint(); });
+
// Reuse the mask of the immediate dominator if the VPBB post-dominates the
// immediate dominator.
auto *IDom = VPDT.getNode(VPBB)->getIDom();
@@ -224,7 +242,117 @@ void VPPredicator::createSwitchEdgeMasks(const VPInstruction *SI) {
setEdgeMask(Src, DefaultDst, DefaultMask);
}
+// Compute the "furthest up" set of edges for each incoming value of a phi.
+//
+// Start by keeping track of what edges lead to which value. Then see if any
+// node has the same value for all outgoing edges. If so then propagate that
+// value up to every node it postdominates.
+MapVector<VPPredicator::EdgeTy, VPValue *>
+VPPredicator::computeBlendEdges(VPPhi *Phi) {
+ MapVector<EdgeTy, VPValue *> Edges;
+
+ // Mark the given edge as providing the value \p V.
+ auto AddEdge = [&Edges](const VPBlockBase *From, const VPBlockBase *To,
+ VPValue *V) {
+ EdgeTy Edge = {cast<VPBasicBlock>(From), cast<VPBasicBlock>(To)};
+ assert((!Edges.contains(Edge) || Edges.lookup(Edge) == V) &&
+ "Clobbering an edge?");
+ Edges[Edge] = V;
+ };
+
+ for (auto [InVal, InVPBB] : Phi->incoming_values_and_blocks())
+ AddEdge(InVPBB, Phi->getParent(), InVal);
+
+ // The root phi must postdominate every incoming block. Also don't touch
+ // phis in a reduction chain since they need to be in a specific structure
+ // for handle*Reductions.
+ for (auto [InVal, InVPBB] : Phi->incoming_values_and_blocks())
+ if (!VPPDT.dominates(Phi->getParent(), InVPBB) ||
+ isa<VPReductionPHIRecipe>(InVal))
+ return Edges;
+
+ // Given a list of edges, check if they all have the same value and return it.
+ auto GetAllEqual = [&Edges](ArrayRef<EdgeTy> OutEdges) -> VPValue * {
+ VPValue *Common = nullptr;
+ for (EdgeTy E : OutEdges) {
+ VPValue *V = Edges.lookup(E);
+ if (!V)
+ return nullptr;
+ if (match(V, m_Poison()))
+ continue;
+ if (!Common)
+ Common = V;
+ else if (Common != V)
+ return nullptr;
+ }
+ return Common;
+ };
+
+ SetVector<const VPBlockBase *> Worklist(from_range, Phi->incoming_blocks());
+ while (!Worklist.empty()) {
+ auto *VPBB = cast<VPBasicBlock>(Worklist.pop_back_val());
+
+ // Check that all outgoing edges from VPBB have the same value.
+ SmallVector<EdgeTy> OutEdges;
+ for (const VPBlockBase *Succ : VPBB->getSuccessors())
+ OutEdges.emplace_back(VPBB, cast<VPBasicBlock>(Succ));
+ VPValue *Common = GetAllEqual(OutEdges);
+ if (!Common)
+ continue;
+
+ // They have the same value: we can move the edges up
+ for (EdgeTy Edge : OutEdges)
+ Edges.erase(Edge);
+
+ // Peek through phis that are postdominated by VPBB
+ if (auto *Phi = dyn_cast<VPPhi>(Common))
+ if (VPPDT.dominates(VPBB, Phi->getParent())) {
+ for (auto [InV, InVPBB] : Phi->incoming_values_and_blocks()) {
+ AddEdge(InVPBB, Phi->getParent(), InV);
+ Worklist.insert(InVPBB);
+ }
+ continue;
+ }
+
+ // Iterate up through the post dominance frontier
+ for (const VPBlockBase *Frontier : VPPDF.find(VPBB)->second) {
+ for (const VPBlockBase *FrontierSucc : Frontier->getSuccessors())
+ if (VPPDT.dominates(VPBB, FrontierSucc))
+ AddEdge(Frontier, FrontierSucc, Common);
+ Worklist.insert(cast<VPBasicBlock>(Frontier));
+ }
+ }
+
+ return Edges;
+}
+
+VPValue *VPPredicator::createMaskDisjunction(ArrayRef<EdgeTy> Edges,
+ VPBasicBlock *VPBB) {
+ auto Dsts = map_range(Edges, [](auto E) { return E.second; });
+ const VPBasicBlock *PostDom = *Dsts.begin();
+ for (const VPBasicBlock *VPBB : drop_begin(Dsts))
+ PostDom =
+ cast<VPBasicBlock>(VPPDT.findNearestCommonDominator(PostDom, VPBB));
+ assert(VPPDT.dominates(VPBB, PostDom) && "Edges don't postdominate VPBB");
+ if (PostDom != VPBB)
+ return getBlockInMask(PostDom);
+
+ VPValue *Mask = nullptr;
+ for (auto [Src, Dst] : Edges) {
+ VPValue *EdgeMask;
+ {
+ VPBuilder::InsertPointGuard Guard(Builder);
+ Builder.setInsertPoint(const_cast<VPBasicBlock *>(Dst),
+ InsertPoints[Dst]);
+ EdgeMask = createEdgeMask(Src, Dst);
+ }
+ Mask = Mask ? Builder.createOr(Mask, EdgeMask) : EdgeMask;
+ }
+ return Mask;
+}
+
void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
+ Builder.setInsertPoint(VPBB, InsertPoints[VPBB]);
SmallVector<VPPhi *> Phis;
for (VPRecipeBase &R : VPBB->phis())
Phis.push_back(cast<VPPhi>(&R));
@@ -245,10 +373,30 @@ void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
continue;
}
+ MapVector<VPValue *, SmallVector<EdgeTy>> InValEdgesMap;
+ for (auto [Edge, Val] : computeBlendEdges(PhiR))
+ InValEdgesMap[Val].push_back(Edge);
+ auto InValEdges = InValEdgesMap.takeVector();
+
+ if (InValEdges.size() == 1) {
+ PhiR->replaceAllUsesWith(InValEdges[0].first);
+ PhiR->eraseFromParent();
+ continue;
+ }
+
+ // Sort the incoming value order to match PhiR as much as possible.
+ llvm::stable_sort(InValEdges, [&PhiR](auto &L, auto &R) {
+ auto InVs = PhiR->incoming_values();
+ return std::distance(InVs.begin(), find(InVs, L.first)) <
+ std::distance(InVs.begin(), find(InVs, R.first));
+ });
+
SmallVector<VPValue *, 2> OperandsWithMask;
- for (const auto &[InVPV, InVPBB] : PhiR->incoming_values_and_blocks()) {
+ for (const auto &[InVPV, Edges] : InValEdges) {
+ if (match(InVPV, m_Poison()))
+ continue;
OperandsWithMask.push_back(InVPV);
- OperandsWithMask.push_back(createEdgeMask(InVPBB, VPBB));
+ OperandsWithMask.push_back(createMaskDisjunction(Edges, VPBB));
}
PHINode *IRPhi = cast_or_null<PHINode>(PhiR->getUnderlyingValue());
auto *Blend =
@@ -276,10 +424,8 @@ void VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan) {
// Introduce the mask for VPBB, which may introduce needed edge masks, and
// convert all phi recipes of VPBB to blend recipes unless VPBB is the
// header.
- if (VPBB != Header) {
+ if (VPBB != Header)
Predicator.createBlockInMask(VPBB);
- Predicator.convertPhisToBlends(VPBB);
- }
VPValue *BlockMask = Predicator.getBlockInMask(VPBB);
if (!BlockMask)
@@ -292,6 +438,10 @@ void VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan) {
}
}
+ for (VPBlockBase *VPB : reverse(RPOT))
+ if (VPB != Header)
+ Predicator.convertPhisToBlends(cast<VPBasicBlock>(VPB));
+
// Linearize the blocks of the loop into one serial chain.
VPBlockBase *PrevVPBB = nullptr;
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index c3867024c34dc..6c5b4c9f87b03 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -475,7 +475,6 @@ Type *llvm::computeScalarTypeForInstruction(unsigned Opcode,
return IntegerType::get(Ctx, 1);
case VPInstruction::LogicalAnd:
case VPInstruction::LogicalOr:
- case VPInstruction::MaskedCond:
assert((!Op0Ty || Op0Ty->isIntegerTy(1)) && "expected bool operand");
AssertOperandType(1, Op0Ty);
return IntegerType::get(Ctx, 1);
@@ -583,7 +582,6 @@ unsigned VPInstruction::getNumOperandsForOpcode() const {
case VPInstruction::ExtractLastLane:
case VPInstruction::ExtractLastPart:
case VPInstruction::ExtractPenultimateElement:
- case VPInstruction::MaskedCond:
case VPInstruction::Not:
case VPInstruction::ResumeForEpilogue:
case VPInstruction::Reverse:
@@ -1533,7 +1531,6 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
case VPInstruction::FirstOrderRecurrenceSplice:
case VPInstruction::LogicalAnd:
case VPInstruction::LogicalOr:
- case VPInstruction::MaskedCond:
case VPInstruction::Not:
case VPInstruction::PtrAdd:
case VPInstruction::WideIVStep:
@@ -1682,9 +1679,6 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
case VPInstruction::ExitingIVValue:
O << "exiting-iv-value";
break;
- case VPInstruction::MaskedCond:
- O << "masked-cond";
- break;
case VPInstruction::ExtractLane:
O << "extract-lane";
break;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5df097628ba7f..dca77b3a3ea5d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1669,11 +1669,6 @@ static void simplifyRecipe(VPSingleDefRecipe *Def) {
return;
}
- // Simplify MaskedCond with no block mask to its single operand.
- if (match(Def, m_VPInstruction<VPInstruction::MaskedCond>()) &&
- !cast<VPInstruction>(Def)->isMasked())
- return Def->replaceAllUsesWith(Def->getOperand(0));
-
// Look through ExtractLastLane.
if (match(Def, m_ExtractLastLane(m_VPValue(A)))) {
if (match(A, m_BuildVector())) {
@@ -2037,6 +2032,15 @@ static void simplifyBlends(VPlan &Plan) {
}
}
+ if (UniqueValues.size() == 2) {
+ for (unsigned I = 0; I != Blend->getNumIncomingValues(); ++I) {
+ if (match(Blend->getIncomingValue(I), m_False())) {
+ StartIndex = I;
+ break;
+ }
+ }
+ }
+
SmallVector<VPValue *, 4> OperandsWithMask;
OperandsWithMask.push_back(Blend->getIncomingValue(StartIndex));
@@ -4126,17 +4130,6 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
continue;
}
- // Lower MaskedCond with block mask to LogicalAnd.
- if (match(&R, m_VPInstruction<VPInstruction::MaskedCond>())) {
- auto *VPI = cast<VPInstruction>(&R);
- assert(VPI->isMasked() &&
- "Unmasked MaskedCond should be simplified earlier");
- VPI->replaceAllUsesWith(Builder.createNaryOp(
- VPInstruction::LogicalAnd, {VPI->getMask(), VPI->getOperand(0)}));
- VPI->eraseFromParent();
- continue;
- }
-
// Lower CanonicalIVIncrementForPart to plain Add.
if (match(
&R,
@@ -4206,6 +4199,7 @@ struct EarlyExitInfo {
VPBasicBlock *EarlyExitingVPBB;
VPIRBasicBlock *EarlyExitVPBB;
VPValue *CondToExit;
+ VPValue *ExitMask;
};
/// Update \p Plan to mask memory operations in the loop based on whether the
@@ -4364,10 +4358,64 @@ static bool handleUncountableExitsWithSideEffects(
return true;
}
+static VPValue *repairSSA(VPValue *Src, VPBasicBlock *SrcVPBB, VPValue *Other,
+ VPBasicBlock *VPBB, VPDominatorTree &VPDT,
+ DenseMap<VPBlockBase *, VPPhi *> &Phis) {
+
+ if (VPDT.dominates(SrcVPBB, VPBB))
+ return Src;
+ if (VPDT.dominates(VPBB, SrcVPBB))
+ return Other;
+ if (VPPhi *Phi = Phis.lookup(VPBB))
+ return Phi;
+
+ SmallVector<VPValue *> InVals;
+ for (auto *Pred : VPBB->predecessors())
+ InVals.push_back(
+ repairSSA(Src, SrcVPBB, Other, cast<VPBasicBlock>(Pred), VPDT, Phis));
+ if (all_equal(InVals))
+ return InVals[0];
+
+ VPPhi *Phi = VPBuilder(VPBB, VPBB->getFirstNonPhi()).createScalarPhi(InVals);
+ Phis[VPBB] = Phi;
+ return Phi;
+}
+
+/// Insert phi nodes to maintain SSA starting from \p VPBB, such that the
+/// resulting value is \p \Src on all paths that go through \p SrcVPBB, and \p
+/// Other otherwise.
+static VPValue *repairSSA(VPValue *Src, VPBasicBlock *SrcVPBB, VPValue *Other,
+ VPBasicBlock *VPBB, VPDominatorTree &VPDT) {
+ DenseMap<VPBlockBase *, VPPhi *> Phis;
+ return repairSSA(Src, SrcVPBB, Other, VPBB, VPDT, Phis);
+}
+
+/// Splits \p LatchVPBB so it only contains the IV increment recipes.
+static VPBasicBlock *splitLatchAtIVInc(VPBasicBlock *LatchVPBB) {
+ auto It = LatchVPBB->getTerminator()->getIterator();
+ while (!match(
+ It->getVPSingleValue(),
+ m_CombineOr(m_Add(m_Isa<VPWidenIntOrFpInductionRecipe>(), m_LiveIn()),
+ m_Sub(m_Isa<VPWidenIntOrFpInductionRecipe>(), m_LiveIn()),
+ m_VPInstruction<Instruction::GetElementPtr>(
+ m_Isa<VPWidenPointerInductionRecipe>(), m_LiveIn())))) {
+ if (It == LatchVPBB->begin())
+ return LatchVPBB;
+ It = std::prev(It);
+ }
+ LatchVPBB = LatchVPBB->splitAt(It);
+ LatchVPBB->setName("vector.latch");
+ return LatchVPBB;
+}
+
bool VPlanTransforms::handleUncountableEarlyExits(
VPlan &Plan, VPBasicBlock *HeaderVPBB, VPBasicBlock *LatchVPBB,
VPBasicBlock *MiddleVPBB, Loop *TheLoop, PredicatedScalarEvolution &PSE,
DominatorTree &DT, AssumptionCache *AC, UncountableExitStyle Style) {
+ // Split the latch at the IV increment so we can branch to it and predicate
+ // any recipes before the increment.
+ LatchVPBB = splitLatchAtIVInc(LatchVPBB);
+
VPDominatorTree VPDT(Plan);
VPBuilder LatchBuilder(LatchVPBB->getTerminator());
SmallVector<EarlyExitInfo> Exits;
@@ -4384,25 +4432,30 @@ bool VPlanTransforms::handleUncountableEarlyExits(
m_BranchOnCond(m_VPValue(CondOfEarlyExitingVPBB)));
assert(Matched && "Terminator must be BranchOnCond");
- // Insert the MaskedCond in the EarlyExitingVPBB so the predicator adds
- // the correct block mask.
VPBuilder EarlyExitingBuilder(EarlyExitingVPBB->getTerminator());
- auto *CondToEarlyExit = EarlyExitingBuilder.createNaryOp(
- VPInstruction::MaskedCond,
+ auto *CondToEarlyExit =
TrueSucc == ExitBlock
? CondOfEarlyExitingVPBB
- : EarlyExitingBuilder.createNot(CondOfEarlyExitingVPBB));
+ : EarlyExitingBuilder.createNot(CondOfEarlyExitingVPBB);
+
+ // Create the exit mask to predicate successors in other lanes.
+ // TODO: This mask doesn't get materialized yet because it's always folded
+ // away. Eventually we need to freeze it to account for the extra use.
+ VPValue *FirstExitLane =
+ EarlyExitingBuilder.createFirstActiveLane(CondToEarlyExit);
+ VPValue *ExitMask = EarlyExitingBuilder.createICmp(
+ CmpInst::ICMP_ULT,
+ EarlyExitingBuilder.createNaryOp(VPInstruction::StepVector, {},
+ FirstExitLane->getScalarType()),
+ FirstExitLane);
+
assert((isa<VPIRValue>(CondOfEarlyExitingVPBB) ||
!VPDT.properlyDominates(EarlyExitingVPBB, LatchVPBB) ||
VPDT.properlyDominates(
CondOfEarlyExitingVPBB->getDefiningRecipe()->getParent(),
LatchVPBB)) &&
"exit condition must dominate the latch");
- Exits.push_back({
- EarlyExitingVPBB,
- ExitBlock,
- CondToEarlyExit,
- });
+ Exits.push_back({EarlyExitingVPBB, ExitBlock, CondToEarlyExit, ExitMask});
}
}
@@ -4516,7 +4569,7 @@ bool VPlanTransforms::handleUncountableEarlyExits(
//
for (auto [Exit, VectorEarlyExitVPBB] :
zip_equal(Exits, VectorEarlyExitVPBBs)) {
- auto &[EarlyExitingVPBB, EarlyExitVPBB, _] = Exit;
+ auto &[EarlyExitingVPBB, EarlyExitVPBB, _, ExitMask] = Exit;
// Adjust the phi nodes in EarlyExitVPBB.
// 1. remove incoming values from EarlyExitingVPBB,
// 2. extract the incoming value at FirstActiveLane
@@ -4540,8 +4593,9 @@ bool VPlanTransforms::handleUncountableEarlyExits(
ExitIRI->addOperand(NewIncoming);
}
- EarlyExitingVPBB->getTerminator()->eraseFromParent();
+ EarlyExitingVPBB->getTerminator()->setOperand(0, ExitMask);
VPBlockUtils::disconnectBlocks(EarlyExitingVPBB, EarlyExitVPBB);
+ VPBlockUtils::connectBlocks(EarlyExitingVPBB, LatchVPBB);
VPBlockUtils::connectBlocks(VectorEarlyExitVPBB, EarlyExitVPBB);
}
@@ -4589,6 +4643,46 @@ bool VPlanTransforms::handleUncountableEarlyExits(
DispatchBuilder.setInsertPoint(CurrentBB);
}
+ // Repair any uses of CondToExit to preserve SSA.
+ VPDT.recalculate(Plan);
+ for (auto [I, Exit] : enumerate(Exits)) {
+ VPValue *Repaired = repairSSA(Exit.CondToExit, Exit.EarlyExitingVPBB,
+ Plan.getFalse(), LatchVPBB,...
[truncated]
|
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
7fa5b89 to
f4e2868
Compare
f4e2868 to
20f97bd
Compare
| return true; | ||
| } | ||
|
|
||
| static VPValue *repairSSA(VPValue *Src, VPBasicBlock *SrcVPBB, VPValue *Other, |
There was a problem hiding this comment.
Rename to repairSSAImpl so that it would be clear why no comment here?
| /// Insert phi nodes to maintain SSA starting from \p VPBB, such that the | ||
| /// resulting value is \p \Src on all paths that go through \p SrcVPBB, and \p | ||
| /// Other otherwise. |
There was a problem hiding this comment.
I don't see any IDF usage, so I guess it assumes that SSA is only being broken by the early exits handling (i.e., in a very specific/restricted way)? If so, document that?
There was a problem hiding this comment.
Yeah, this isn't "SSA broken by adding more definitions" but "SSA broken by adding more control flow". So I don't think the IDF helps here. I couldn't think of any immediately obvious way in which the IPDF would help either. Will document.
We could in theory reuse this for foldTailByMasking but it's overkill since we only ever need to insert a single phi there.
There was a problem hiding this comment.
What I meant is that we have a very specific pattern how edges are removed/added, so we know where new phis are possibly required.
| } | ||
|
|
||
| /// Splits \p LatchVPBB so it only contains the IV increment recipes. | ||
| static VPBasicBlock *splitLatchAtIVInc(VPBasicBlock *LatchVPBB) { |
There was a problem hiding this comment.
Is this similar to
llvm-project/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
Lines 1381 to 1392 in 52ba9fa
? Why don't we have a similar assert to make sure we don't leave anything unexpected in the "new" latch?
There was a problem hiding this comment.
I've removed the latch splitting and branching part of this PR and split it off into #203263. This PR is now focused exclusively on maintaining SSA and replacing MaskedCond
We don't need it for now
20f97bd to
42e1f59
Compare
* Reuse previous method in DomiananceFrontier * Replace GetAllEqual with a map_range
After thinking about this for a bit this isn't needed. If a phi doesn't postdominate an incoming block, the incoming block will have an outgoing edge with no value. So we won't propagate any further up that incoming block anyway. What differs between this approach and llvm#184838 is that the latter performs a full inverse DFS to see what blocks are reachable, whereas this just checks that the incoming values are the same at each postdominance frontier. The test case phi_doesnt_postdom_incoming shows a scenario where the full inverse DFS approach could simplify the edge to just c1 and !c1, but we calculate the conservative (but still correct) edges in this PR.
42e1f59 to
296386b
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
296386b to
4513c88
Compare
* Explain when to use repairSSA * repairSSA -> repairSSAImpl
4513c88 to
74457c7
Compare
Stacked on #203164
If we want to allow tail folding with early exit loops, we will need to combine both a header mask and an early exit mask. To do so a subsequent PR will implement masking via the CFG, so that it can be composed with the tail folding transformation, e.g.:
However with the CFG approach we end up with redundant masks produced by MaskedCond in multi-early-exit loops, stemming from the extra predication in the previous early exits.
Early exit vectorization today also has the issue where predicated exiting blocks do not preserve SSA, as CondToExit does not dominate their use in the AnyOf in the latch.
This PR fixes this by replacing MaskedCond with explicit phi nodes in the CFG, where the incoming value is false on paths that didn't pass through the exiting block. It also repairs any live outs that were defined in an exiting block. This allows any exit masks from prior exiting blocks to be folded away when they're added in a subsequent PR.
This also has the advantage of maintaining SSA form, so we no longer need the VPlanVerifier bypasses for MaskedCond.