summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/Analysis/TargetTransformInfo.cpp 53
-rw-r--r-- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp 15
2 files changed, 65 insertions, 3 deletions
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index f3d20ce984d..6730aa86a99 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -571,11 +571,64 @@ TargetTransformInfo::getOperandInfo(Value *V, OperandValueProperties &OpProps) {
return OpInfo;
}
+/// Try to cost an 'or' as part of a load-combine pattern.
+///
+/// Starting at the first value in \p Args, walk through operand 0 of any
+/// chain of 'or' / 'shl-by-constant' instructions. If that walk ends at a
+/// zext of a scalar integer load whose widened type is legal, return the
+/// relative cost of replacing {narrow load + zext + shl + or} with one wide
+/// load. The result may be negative.
+///
+/// \param Opcode the opcode being costed; only Instruction::Or is handled.
+/// \param Args   the operands of the instruction being costed.
+/// \returns the relative cost, or None if the pattern does not match.
+Optional<int>
+TargetTransformInfo::getLoadCombineCost(unsigned Opcode,
+                                        ArrayRef<const Value *> Args) const {
+  if (Opcode != Instruction::Or || Args.empty())
+    return llvm::None;
+
+  // Look past the reduction to find a source value. Arbitrarily follow the
+  // path through operand 0 of any 'or'. Also, peek through optional
+  // shift-left-by-constant.
+  // The m_Or/m_Shl matchers also accept constant expressions, which are not
+  // BinaryOperators, so stop at those before the cast below can assert.
+  const Value *ZextLoad = Args.front();
+  while (!isa<ConstantExpr>(ZextLoad) &&
+         (match(ZextLoad, m_Or(m_Value(), m_Value())) ||
+          match(ZextLoad, m_Shl(m_Value(), m_Constant()))))
+    ZextLoad = cast<BinaryOperator>(ZextLoad)->getOperand(0);
+
+  // Check if the input to the reduction is an extended load.
+  Value *LoadPtr;
+  if (!match(ZextLoad, m_ZExt(m_Load(m_Value(LoadPtr)))))
+    return llvm::None;
+
+  // Only scalar integers are handled: m_ZExt also matches a vector zext, and
+  // getIntegerBitWidth() asserts on anything that is not an integer type.
+  Type *WideType = ZextLoad->getType();
+  Type *EltType = LoadPtr->getType()->getPointerElementType();
+  if (!WideType->isIntegerTy() || !EltType->isIntegerTy())
+    return llvm::None;
+
+  // Require that the total load bit width is a legal integer type.
+  // For example, <8 x i8> --> i64 is a legal integer on a 64-bit target.
+  // But <16 x i8> --> i128 is not, so the backend probably can't reduce it.
+  unsigned WideWidth = WideType->getIntegerBitWidth();
+  unsigned EltWidth = EltType->getIntegerBitWidth();
+  if (!isTypeLegal(WideType) || WideWidth % EltWidth != 0)
+    return llvm::None;
+
+  // Calculate relative cost: {narrow load+zext+shl+or} are assumed to be
+  // removed and replaced by a single wide load.
+  // FIXME: This is not accurate for the larger pattern where we replace
+  //        multiple narrow load sequences with just 1 wide load. We could
+  //        remove the addition of the wide load cost here and expect the
+  //        caller to make an adjustment for that.
+  int Cost = 0;
+  Cost -= getMemoryOpCost(Instruction::Load, EltType, 0, 0);
+  Cost -= getCastInstrCost(Instruction::ZExt, WideType, EltType);
+  Cost -= getArithmeticInstrCost(Instruction::Shl, WideType);
+  Cost -= getArithmeticInstrCost(Instruction::Or, WideType);
+  Cost += getMemoryOpCost(Instruction::Load, WideType, 0, 0);
+  return Cost;
+}
+
+
int TargetTransformInfo::getArithmeticInstrCost(
unsigned Opcode, Type *Ty, OperandValueKind Opd1Info,
OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo,
OperandValueProperties Opd2PropInfo,
ArrayRef<const Value *> Args) const {
+ // Check if we can match this instruction as part of a larger pattern.
+ Optional<int> LoadCombineCost = getLoadCombineCost(Opcode, Args);
+ if (LoadCombineCost)
+ return LoadCombineCost.getValue();
+
+ // Fall back to implementation-specific overrides or base class.
int Cost = TTIImpl->getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info,
Opd1PropInfo, Opd2PropInfo, Args);
assert(Cost >= 0 && "TTI should not produce negative costs!");
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 99428c6c5de..ad12646bdee 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6499,10 +6499,19 @@ private:
int ScalarReduxCost = 0;
switch (ReductionData.getKind()) {
- case RK_Arithmetic:
- ScalarReduxCost =
- TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy);
+ case RK_Arithmetic: {
+ // Note: Passing in the reduction operands allows the cost model to match
+ // load combining patterns for this reduction.
+ auto *ReduxInst = cast<Instruction>(ReductionRoot);
+ SmallVector<const Value *, 2> OperandList;
+ for (Value *Operand : ReduxInst->operands())
+ OperandList.push_back(Operand);
+ ScalarReduxCost = TTI->getArithmeticInstrCost(ReductionData.getOpcode(),
+ ScalarTy, TargetTransformInfo::OK_AnyValue,
+ TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None,
+ TargetTransformInfo::OP_None, OperandList);
break;
+ }
case RK_Min:
case RK_Max:
case RK_UMin:
OpenPOWER on IntegriCloud