Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/Analysis/DominanceFrontier.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DominanceFrontierBase {
iterator end() { return Frontiers.end(); }
const_iterator end() const { return Frontiers.end(); }
iterator find(BlockT *B) { return Frontiers.find(B); }
const_iterator find(BlockT *B) const { return Frontiers.find(B); }
const_iterator find(const BlockT *B) const { return Frontiers.find(B); }

/// print - Convert to human readable form
///
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1381,9 +1381,6 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,

/// Returns true if the VPInstruction does not need masking.
bool alwaysUnmasked() const {
if (Opcode == VPInstruction::MaskedCond)
return false;

// For now only VPInstructions with underlying values use masks.
// TODO: provide masks to VPInstructions w/o underlying values.
if (!getUnderlyingValue())
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "VPlan.h"
#include "VPlanCFG.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/Analysis/DominanceFrontier.h"
#include "llvm/Analysis/DominanceFrontierImpl.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Support/GenericDomTree.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
Expand Down Expand Up @@ -67,5 +69,11 @@ template <>
struct GraphTraits<const VPDomTreeNode *>
: public DomTreeGraphTraitsBase<const VPDomTreeNode,
VPDomTreeNode::const_iterator> {};

class VPPostDominanceFrontier
: public DominanceFrontierBase<VPBlockBase, true> {
public:
explicit VPPostDominanceFrontier(const DomTreeT &VPDT) { analyze(VPDT); }
};
} // namespace llvm
#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANDOMINATORTREE_H
148 changes: 145 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class VPPredicator {
/// Post-dominator tree for the VPlan.
VPPostDominatorTree VPPDT;

/// Post-dominator frontier for the VPlan.
VPPostDominanceFrontier VPPDF;

/// When we if-convert we need to create edge masks. We have to cache values
/// so that we don't end up with exponential recursion/IR.
using EdgeMaskCacheTy =
Expand Down Expand Up @@ -78,8 +81,17 @@ class VPPredicator {
return VPBB->getFirstNonPhi();
}

using EdgeTy = std::pair<const VPBasicBlock *, const VPBasicBlock *>;

/// Compute the "furthest up" set of edges for each incoming value of \Phi.
MapVector<EdgeTy, VPValue *> computeBlendEdges(VPPhi *Phi);

/// Given a set of \p Edges that each can reach \p VPBB, return the OR of all
/// edges, or an equivalent block in-mask.
VPValue *createBlendMaskForEdges(ArrayRef<EdgeTy> Edges, VPBasicBlock *VPBB);

public:
VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan) {}
VPPredicator(VPlan &Plan) : VPDT(Plan), VPPDT(Plan), VPPDF(VPPDT) {}

/// Returns the *entry* mask for \p VPBB.
VPValue *getBlockInMask(const VPBasicBlock *VPBB) const {
Expand Down Expand Up @@ -233,6 +245,118 @@ void VPPredicator::createSwitchEdgeMasks(const VPInstruction *SI) {
setEdgeMask(Src, DefaultDst, DefaultMask);
}

// Start by keeping track of what edges lead to which value. Then see if any
// node has the same value for all outgoing edges. If so then propagate that
// value up to every node it postdominates. E.g:
//
// Entry Edges = {C->ɸ : %x, D->ɸ : %x, F->ɸ : %y}
// / \ [C,D,F all outgoing edges equal: go up postdom frontier]
// A B ~> {A->C : %x, A->D : %x, Entry->B : %y}
// / \ |\ [A all outgoing edges equal: go up postdom frontier]
// C D | E ~> {Entry->A : %x, Entry->B : %y}
// \ \ |/
// \ | F
// \ | /
// ɸ = phi [%x, C], [%x, D], [%y, F]
MapVector<VPPredicator::EdgeTy, VPValue *>
VPPredicator::computeBlendEdges(VPPhi *Phi) {
MapVector<EdgeTy, VPValue *> Edges;

// Mark the given edge as providing the value \p V.
auto AddEdge = [&Edges](const VPBlockBase *From, const VPBlockBase *To,
VPValue *V) {
EdgeTy Edge = {cast<VPBasicBlock>(From), cast<VPBasicBlock>(To)};
assert((!Edges.contains(Edge) || Edges.lookup(Edge) == V) &&
"Clobbering an edge?");
Edges[Edge] = V;
};

for (auto [InVal, InVPBB] : Phi->incoming_values_and_blocks())
AddEdge(InVPBB, Phi->getParent(), InVal);

// Don't optimize any reduction chains for now.
if (any_of(Phi->incoming_values(), IsaPred<VPReductionPHIRecipe>))
return Edges;

SetVector<const VPBlockBase *> Worklist(from_range, Phi->incoming_blocks());
while (!Worklist.empty()) {
auto *VPBB = cast<VPBasicBlock>(Worklist.pop_back_val());

// Check that all outgoing edges from VPBB have the same value.
SmallVector<EdgeTy> OutEdges;
for (const VPBlockBase *Succ : VPBB->getSuccessors())
OutEdges.emplace_back(VPBB, cast<VPBasicBlock>(Succ));
auto OutVals =
map_range(OutEdges, [&Edges](EdgeTy E) { return Edges.lookup(E); });
VPValue *Common = *OutVals.begin();
if (!Common || !all_equal(OutVals))
continue;

// They have the same value: we can move the edges up.
for (EdgeTy Edge : OutEdges)
Edges.erase(Edge);

// If the value is a phi postdominated by VPBB, then look through the inner
// incoming values instead of propagating the phi.
if (auto *Phi = dyn_cast<VPPhi>(Common))
if (Phi->hasOneUse() && VPPDT.dominates(VPBB, Phi->getParent())) {
for (auto [InV, InVPBB] : Phi->incoming_values_and_blocks()) {
AddEdge(InVPBB, Phi->getParent(), InV);
Worklist.insert(InVPBB);
}
continue;
}

// Iterate up through the post dominance frontier.
for (const VPBlockBase *Frontier : VPPDF.find(VPBB)->second) {
for (const VPBlockBase *FrontierSucc : Frontier->getSuccessors())
if (VPPDT.dominates(VPBB, FrontierSucc))
AddEdge(Frontier, FrontierSucc, Common);
Worklist.insert(cast<VPBasicBlock>(Frontier));
}
}

return Edges;
}

VPValue *VPPredicator::createBlendMaskForEdges(ArrayRef<EdgeTy> Edges,
VPBasicBlock *VPBB) {
// If the nearest common postdominator to all of Edges destinations isn't VPBB
// then we can use its block in-mask. E.g:
//
// A ... B
// \ \ /
// \ C
// \ /
// ... D ...
// \ | /
// VPBB
//
// If the edges are A->D and B->C, PostDom will be D. We can reuse Ds block
// in-mask.
const VPBasicBlock *PostDom = Edges[0].second;
for (auto [_, VPBB] : drop_begin(Edges))
PostDom =
cast<VPBasicBlock>(VPPDT.findNearestCommonDominator(PostDom, VPBB));
assert(VPPDT.dominates(VPBB, PostDom) && "VPBB doesn't postdominate edges");
if (PostDom != VPBB)
return getBlockInMask(PostDom);

// Otherwise, compute the disjunction of edges.
VPValue *Mask = nullptr;
for (auto [Src, ConstDst] : Edges) {
auto *Dst = const_cast<VPBasicBlock *>(ConstDst);
VPValue *EdgeMask;
{
VPBuilder::InsertPointGuard Guard(Builder);
Builder.setInsertPoint(Dst, getMaskInsertPoint(Dst));
EdgeMask = createEdgeMask(Src, Dst);
}
Mask = Mask ? Builder.createOr(Mask, EdgeMask) : EdgeMask;
}
return Mask;
}

void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
Builder.setInsertPoint(VPBB, getMaskInsertPoint(VPBB));

Expand All @@ -256,10 +380,28 @@ void VPPredicator::convertPhisToBlends(VPBasicBlock *VPBB) {
continue;
}

MapVector<VPValue *, SmallVector<EdgeTy>> InValEdgesMap;
for (auto [Edge, Val] : computeBlendEdges(PhiR))
InValEdgesMap[Val].push_back(Edge);
auto InValEdges = InValEdgesMap.takeVector();

if (InValEdges.size() == 1) {
PhiR->replaceAllUsesWith(InValEdges[0].first);
PhiR->eraseFromParent();
continue;
}

// Sort the incoming value order to match PhiR as much as possible.
llvm::stable_sort(InValEdges, [&PhiR](auto &L, auto &R) {
auto InVs = PhiR->incoming_values();
return std::distance(InVs.begin(), find(InVs, L.first)) <
std::distance(InVs.begin(), find(InVs, R.first));
});

SmallVector<VPValue *, 2> OperandsWithMask;
for (const auto &[InVPV, InVPBB] : PhiR->incoming_values_and_blocks()) {
for (const auto &[InVPV, Edges] : InValEdges) {
OperandsWithMask.push_back(InVPV);
OperandsWithMask.push_back(createEdgeMask(InVPBB, VPBB));
OperandsWithMask.push_back(createBlendMaskForEdges(Edges, VPBB));
}
PHINode *IRPhi = cast_or_null<PHINode>(PhiR->getUnderlyingValue());
auto *Blend =
Expand Down
5 changes: 0 additions & 5 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,6 @@ unsigned VPInstruction::getNumOperandsForOpcode() const {
case VPInstruction::ExtractLastLane:
case VPInstruction::ExtractLastPart:
case VPInstruction::ExtractPenultimateElement:
case VPInstruction::MaskedCond:
case VPInstruction::Not:
case VPInstruction::Reverse:
case VPInstruction::Unpack:
Expand Down Expand Up @@ -1630,7 +1629,6 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
case VPInstruction::FirstOrderRecurrenceSplice:
case VPInstruction::LogicalAnd:
case VPInstruction::LogicalOr:
case VPInstruction::MaskedCond:
case VPInstruction::Not:
case VPInstruction::PtrAdd:
case VPInstruction::WideIVStep:
Expand Down Expand Up @@ -1780,9 +1778,6 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
case VPInstruction::ExitingIVValue:
O << "exiting-iv-value";
break;
case VPInstruction::MaskedCond:
O << "masked-cond";
break;
case VPInstruction::ExtractLane:
O << "extract-lane";
break;
Expand Down
99 changes: 78 additions & 21 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1690,11 +1690,6 @@ static void simplifyRecipe(VPSingleDefRecipe *Def) {
return;
}

// Simplify MaskedCond with no block mask to its single operand.
if (match(Def, m_VPInstruction<VPInstruction::MaskedCond>()) &&
!cast<VPInstruction>(Def)->isMasked())
return Def->replaceAllUsesWith(Def->getOperand(0));

// Look through ExtractLastLane.
if (match(Def, m_ExtractLastLane(m_VPValue(A)))) {
if (match(A, m_BuildVector())) {
Expand Down Expand Up @@ -4215,17 +4210,6 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
continue;
}

// Lower MaskedCond with block mask to LogicalAnd.
if (match(&R, m_VPInstruction<VPInstruction::MaskedCond>())) {
auto *VPI = cast<VPInstruction>(&R);
assert(VPI->isMasked() &&
"Unmasked MaskedCond should be simplified earlier");
VPI->replaceAllUsesWith(Builder.createNaryOp(
VPInstruction::LogicalAnd, {VPI->getMask(), VPI->getOperand(0)}));
VPI->eraseFromParent();
continue;
}

// Lower CanonicalIVIncrementForPart to plain Add.
if (match(
&R,
Expand Down Expand Up @@ -4410,6 +4394,76 @@ struct EarlyExitInfo {
VPIRBasicBlock *EarlyExitVPBB;
VPValue *CondToExit;
};
static VPValue *repairSSAImpl(VPValue *Src, VPBasicBlock *SrcVPBB,
VPValue *Other, VPBasicBlock *VPBB,
VPDominatorTree &VPDT,
DenseMap<VPBlockBase *, VPPhi *> &Phis) {

if (VPDT.dominates(SrcVPBB, VPBB))
return Src;
if (VPDT.dominates(VPBB, SrcVPBB))
return Other;
if (VPPhi *Phi = Phis.lookup(VPBB))
return Phi;

SmallVector<VPValue *> InVals;
for (auto *Pred : VPBB->predecessors())
InVals.push_back(repairSSAImpl(Src, SrcVPBB, Other,
cast<VPBasicBlock>(Pred), VPDT, Phis));
if (all_equal(InVals))
return InVals[0];

VPPhi *Phi = VPBuilder(VPBB, VPBB->getFirstNonPhi()).createScalarPhi(InVals);
Phis[VPBB] = Phi;
return Phi;
}

/// Insert phi nodes to maintain SSA starting from \p VPBB, such that the
/// resulting value is \p \Src on all paths that go through \p SrcVPBB, and \p
/// Other otherwise. Use if the CFG has been modified such that a def no longer
/// dominates all its uses.
static VPValue *repairSSA(VPValue *Src, VPBasicBlock *SrcVPBB, VPValue *Other,
VPBasicBlock *VPBB, VPDominatorTree &VPDT) {
DenseMap<VPBlockBase *, VPPhi *> Phis;
return repairSSAImpl(Src, SrcVPBB, Other, VPBB, VPDT, Phis);
}

// After handling early exits, the CondToExits and live outs may no longer be in
// SSA if their defining blocks are predicated, so insert phis to repair them.
static void repairEarlyExitSSA(VPlan &Plan, VPDominatorTree &VPDT,
ArrayRef<EarlyExitInfo> Exits,
VPBasicBlock *LatchVPBB,
ArrayRef<VPBasicBlock *> LiveOutVPBBs) {
// Repair all CondToExits. The condition is false on any path that doesn't go
// through the exiting block.
for (auto [EarlyExitingVPBB, _, CondToExit] : Exits) {
VPValue *Repaired = repairSSA(CondToExit, EarlyExitingVPBB, Plan.getFalse(),
LatchVPBB, VPDT);

CondToExit->replaceUsesWithIf(Repaired, [&](VPUser &U, unsigned I) {
auto &R = cast<VPRecipeBase>(U);
return VPDT.dominates(LatchVPBB, R.getParent()) &&
R.getVPSingleValue() != Repaired;
});
}

// Repair any live outs. The value is poison on any path that didn't pass
// through the def's block.
for (VPBasicBlock *LiveOutVPBB : LiveOutVPBBs)
for (VPRecipeBase &R : *LiveOutVPBB) {
VPValue *LiveOut;
if (!match(&R,
m_CombineOr(m_ExtractLastPart(m_VPValue(LiveOut)),
m_ExtractLane(m_VPValue(), m_VPValue(LiveOut)))))
continue;
VPValue *Poison =
Plan.getOrAddLiveIn(PoisonValue::get(LiveOut->getScalarType()));
VPValue *Repaired =
repairSSA(LiveOut, LiveOut->getDefiningRecipe()->getParent(), Poison,
LatchVPBB, VPDT);
R.replaceUsesOfWith(LiveOut, Repaired);
}
}

/// Update \p Plan to mask memory operations in the loop based on whether the
/// early exit is taken or not.
Expand Down Expand Up @@ -4615,14 +4669,12 @@ bool VPlanTransforms::handleUncountableEarlyExits(
m_BranchOnCond(m_VPValue(CondOfEarlyExitingVPBB)));
assert(Matched && "Terminator must be BranchOnCond");

// Insert the MaskedCond in the EarlyExitingVPBB so the predicator adds
// the correct block mask.
VPBuilder EarlyExitingBuilder(EarlyExitingVPBB->getTerminator());
auto *CondToEarlyExit = EarlyExitingBuilder.createNaryOp(
VPInstruction::MaskedCond,
auto *CondToEarlyExit =
TrueSucc == ExitBlock
? CondOfEarlyExitingVPBB
: EarlyExitingBuilder.createNot(CondOfEarlyExitingVPBB));
: EarlyExitingBuilder.createNot(CondOfEarlyExitingVPBB);

assert((isa<VPIRValue>(CondOfEarlyExitingVPBB) ||
!VPDT.properlyDominates(EarlyExitingVPBB, LatchVPBB) ||
VPDT.properlyDominates(
Expand Down Expand Up @@ -4819,6 +4871,11 @@ bool VPlanTransforms::handleUncountableEarlyExits(
DispatchBuilder.setInsertPoint(CurrentBB);
}

VPDT.recalculate(Plan);
SmallVector<VPBasicBlock *> LiveOutVPBBs = {MiddleVPBB};
append_range(LiveOutVPBBs, VectorEarlyExitVPBBs);
repairEarlyExitSSA(Plan, VPDT, Exits, LatchVPBB, LiveOutVPBBs);

return true;
}

Expand Down
Loading