diff options
Diffstat (limited to 'llvm/lib/Analysis/CostModel.cpp')
-rw-r--r-- | llvm/lib/Analysis/CostModel.cpp | 110 |
1 files changed, 69 insertions, 41 deletions
diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp index e3fa10d40fd..071e23e90ff 100644 --- a/llvm/lib/Analysis/CostModel.cpp +++ b/llvm/lib/Analysis/CostModel.cpp @@ -24,12 +24,14 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +using namespace PatternMatch; #define CM_NAME "cost-model" #define DEBUG_TYPE CM_NAME @@ -183,27 +185,46 @@ static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, return Mask == ActualMask; } -static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, - unsigned Level, unsigned NumLevels) { +namespace { +/// Contains opcode + LHS/RHS parts of the reduction operations. +struct ReductionData { + explicit ReductionData() = default; + ReductionData(unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS) {} + unsigned Opcode = 0; + Value *LHS = nullptr; + Value *RHS = nullptr; +}; +} // namespace + +static Optional<ReductionData> getReductionData(Instruction *I) { + Value *L, *R; + if (m_BinOp(m_Value(L), m_Value(R)).match(I)) + return ReductionData(I->getOpcode(), L, R); + return llvm::None; +} + +static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 - if (BinOp == nullptr) + if (!I) return false; - assert(BinOp->getType()->isVectorTy() && "Expecting a vector type"); + assert(I->getType()->isVectorTy() && "Expecting a vector type"); - unsigned Opcode = BinOp->getOpcode(); - Value *L = BinOp->getOperand(0); - Value *R = BinOp->getOperand(1); + Optional<ReductionData> RD = getReductionData(I); + if (!RD) + return false; - ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L); + ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS); if (!LS && Level) return false; - ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R); + ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS); if (!RS && Level) return false; @@ -228,31 +249,30 @@ static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, // Example: // %NextLevelOpL = shufflevector %R, <1, undef ...> // %BinOp = fadd %NextLevelOpL, %R - if (NextLevelOpL && NextLevelOpL != R) + if (NextLevelOpL && NextLevelOpL != RD->RHS) return false; - else if (NextLevelOpR && NextLevelOpR != L) + else if (NextLevelOpR && NextLevelOpR != RD->LHS) return false; - NextLevelOp = NextLevelOpL ? R : L; + NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS; } else return false; // Check that the next levels binary operation exists and matches with the // current one. - BinaryOperator *NextLevelBinOp = nullptr; if (Level + 1 != NumLevels) { - if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp))) - return false; - else if (NextLevelBinOp->getOpcode() != Opcode) + Optional<ReductionData> NextLevelRD = + getReductionData(cast<Instruction>(NextLevelOp)); + if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode) return false; } // Shuffle mask for pairwise operation must match. - if (matchPairwiseShuffleMask(LS, true, Level)) { - if (!matchPairwiseShuffleMask(RS, false, Level)) + if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) return false; - } else if (matchPairwiseShuffleMask(RS, true, Level)) { - if (!matchPairwiseShuffleMask(LS, false, Level)) + } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) return false; } else return false; @@ -261,7 +281,8 @@ static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, return true; // Match next level. - return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels); + return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level, + NumLevels); } static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, @@ -277,11 +298,14 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0)); + auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); if (!RdxStart) return false; + Optional<ReductionData> RD = getReductionData(RdxStart); + if (!RD) + return false; - Type *VecTy = ReduxRoot->getOperand(0)->getType(); + Type *VecTy = RdxStart->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) return false; @@ -307,17 +331,14 @@ static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) return false; - Opcode = RdxStart->getOpcode(); + Opcode = RD->Opcode; Ty = VecTy; return true; } static std::pair<Value *, ShuffleVectorInst *> -getShuffleAndOtherOprd(BinaryOperator *B) { - - Value *L = B->getOperand(0); - Value *R = B->getOperand(1); +getShuffleAndOtherOprd(Value *L, Value *R) { ShuffleVectorInst *S = nullptr; if ((S = dyn_cast<ShuffleVectorInst>(L))) @@ -340,10 +361,12 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0)); + auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); if (!RdxStart) return false; - unsigned RdxOpcode = RdxStart->getOpcode(); + Optional<ReductionData> RD = getReductionData(RdxStart); + if (!RD) + return false; Type *VecTy = ReduxRoot->getOperand(0)->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); @@ -362,20 +385,21 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, // %r = extractelement <4 x float> %bin.rdx8, i32 0 unsigned MaskStart = 1; - Value *RdxOp = RdxStart; + Instruction *RdxOp = RdxStart; SmallVector<int, 32> ShuffleMask(NumVecElems, 0); unsigned NumVecElemsRemain = NumVecElems; while (NumVecElemsRemain - 1) { // Check for the right reduction operation. - BinaryOperator *BinOp; - if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp))) + if (!RdxOp) return false; - if (BinOp->getOpcode() != RdxOpcode) + Optional<ReductionData> RDLevel = getReductionData(RdxOp); + if (!RDLevel || RDLevel->Opcode != RD->Opcode) return false; Value *NextRdxOp; ShuffleVectorInst *Shuffle; - std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp); + std::tie(NextRdxOp, Shuffle) = + getShuffleAndOtherOprd(RDLevel->LHS, RDLevel->RHS); // Check the current reduction operation and the shuffle use the same value. if (Shuffle == nullptr) @@ -393,12 +417,12 @@ static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, if (ShuffleMask != Mask) return false; - RdxOp = NextRdxOp; + RdxOp = dyn_cast<Instruction>(NextRdxOp); NumVecElemsRemain /= 2; MaskStart *= 2; } - Opcode = RdxOpcode; + Opcode = RD->Opcode; Ty = VecTy; return true; } @@ -495,10 +519,14 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { unsigned ReduxOpCode; Type *ReduxType; - if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, false); - else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, true); + if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/false); + } + if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/true); + } return TTI->getVectorInstrCost(I->getOpcode(), EEI->getOperand(0)->getType(), Idx); |