diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 30 |
1 files changed, 25 insertions, 5 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index b3188001e11..d03b55756d3 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -43,6 +43,11 @@ using namespace PatternMatch; static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape", cl::init(true)); +static cl::opt<bool> AllowContractEnabled( + "matrix-allow-contract", cl::init(false), cl::Hidden, + cl::desc("Allow the use of FMAs if available and profitable. This may " + "result in different results, due to less rounding error.")); + namespace { // Given an element poitner \p BasePtr to the start of a (sub) matrix, compute @@ -536,12 +541,25 @@ public: } Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, - IRBuilder<> &Builder) { - Value *Mul = UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); + IRBuilder<> &Builder, bool AllowContraction) { + if (!Sum) - return Mul; + return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); + + if (UseFPOp) { + if (AllowContraction) { + // Use fmuladd for floating point operations and let the backend decide + // if that's profitable. + Value *FMulAdd = Intrinsic::getDeclaration( + Func.getParent(), Intrinsic::fmuladd, A->getType()); + return Builder.CreateCall(FMulAdd, {A, B, Sum}); + } + Value *Mul = Builder.CreateFMul(A, B); + return Builder.CreateFAdd(Sum, Mul); + } - return UseFPOp ? Builder.CreateFAdd(Sum, Mul) : Builder.CreateAdd(Sum, Mul); + Value *Mul = Builder.CreateMul(A, B); + return Builder.CreateAdd(Sum, Mul); } /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For @@ -591,6 +609,8 @@ public: EltType->getPrimitiveSizeInBits(), uint64_t(1)); + bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && + MatMul->hasAllowContract()); // Multiply columns from the first operand with scalars from the second // operand. Then move along the K axes and accumulate the columns. With // this the adds can be vectorized without reassociation. @@ -607,7 +627,7 @@ public: Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K); Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(), - Builder); + Builder, AllowContract); } Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder)); } |

