diff options
author | Sanjay Patel <spatel@rotateright.com> | 2014-10-16 18:48:17 +0000 |
---|---|---|
committer | Sanjay Patel <spatel@rotateright.com> | 2014-10-16 18:48:17 +0000 |
commit | c699a6117b0f33739cdbe63fff46f95c79b5133b (patch) | |
tree | e29dbf975378a9a26369ce78ae503317baaf6026 /llvm/lib/Transforms | |
parent | d70f3c20b8c0ff71638ac2ee774b4e5a021be521 (diff) | |
download | bcm5719-llvm-c699a6117b0f33739cdbe63fff46f95c79b5133b.tar.gz bcm5719-llvm-c699a6117b0f33739cdbe63fff46f95c79b5133b.zip |
fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)
If a square root call has an FP multiplication argument that can be reassociated,
then we can hoist a repeated factor out of the square root call and into a fabs().
In the simplest case, this:
`y = sqrt(x * x);`
becomes this:
`y = fabs(x);`
This patch relies on an earlier optimization in instcombine or reassociate to put the
multiplication tree into a canonical form, so we don't have to search over
every permutation of the multiplication tree.
Because there are no IR-level FastMathFlags for intrinsics (PR21290), we have to
use function-level attributes to do this optimization. This needs to be fixed
for both the intrinsics and in the backend.
Differential Revision: http://reviews.llvm.org/D5787
llvm-svn: 219944
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp | 88 |
1 files changed, 87 insertions, 1 deletions
```diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 9fac7ef540e..c3e2f3aec00 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -27,12 +27,14 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetLibraryInfo.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 static cl::opt<bool>
     ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
   return Ret;
 }
 
+Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+
+  Value *Ret = nullptr;
+  if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
+      TLI->has(LibFunc::sqrtf)) {
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  }
+
+  // FIXME: For finer-grain optimization, we need intrinsics to have the same
+  // fast-math flag decorations that are applied to FP instructions. For now,
+  // we have to rely on the function-level unsafe-fp-math attribute to do this
+  // optimization because there's no other way to express that the sqrt can be
+  // reassociated.
+  Function *F = CI->getParent()->getParent();
+  if (F->hasFnAttribute("unsafe-fp-math")) {
+    // Check for unsafe-fp-math = true.
+    Attribute Attr = F->getFnAttribute("unsafe-fp-math");
+    if (Attr.getValueAsString() != "true")
+      return Ret;
+  }
+  Value *Op = CI->getArgOperand(0);
+  if (Instruction *I = dyn_cast<Instruction>(Op)) {
+    if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
+      // We're looking for a repeated factor in a multiplication tree,
+      // so we can do this fold: sqrt(x * x) -> fabs(x);
+      // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
+      Value *Op0 = I->getOperand(0);
+      Value *Op1 = I->getOperand(1);
+      Value *RepeatOp = nullptr;
+      Value *OtherOp = nullptr;
+      if (Op0 == Op1) {
+        // Simple match: the operands of the multiply are identical.
+        RepeatOp = Op0;
+      } else {
+        // Look for a more complicated pattern: one of the operands is itself
+        // a multiply, so search for a common factor in that multiply.
+        // Note: We don't bother looking any deeper than this first level or for
+        // variations of this pattern because instcombine's visitFMUL and/or the
+        // reassociation pass should give us this form.
+        Value *OtherMul0, *OtherMul1;
+        if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
+          // Pattern: sqrt((x * y) * z)
+          if (OtherMul0 == OtherMul1) {
+            // Matched: sqrt((x * x) * z)
+            RepeatOp = OtherMul0;
+            OtherOp = Op1;
+          }
+        }
+      }
+      if (RepeatOp) {
+        // Fast math flags for any created instructions should match the sqrt
+        // and multiply.
+        // FIXME: We're not checking the sqrt because it doesn't have
+        // fast-math-flags (see earlier comment).
+        IRBuilder<true, ConstantFolder,
+                  IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B);
+        B.SetFastMathFlags(I->getFastMathFlags());
+        // If we found a repeated factor, hoist it out of the square root and
+        // replace it with the fabs of that factor.
+        Module *M = Callee->getParent();
+        Type *ArgType = Op->getType();
+        Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
+        Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
+        if (OtherOp) {
+          // If we found a non-repeated factor, we still need to get its square
+          // root. We then multiply that by the value that was simplified out
+          // of the square root calculation.
+          Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
+          Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
+          return B.CreateFMul(FabsCall, SqrtCall);
+        }
+        return FabsCall;
+      }
+    }
+  }
+  return Ret;
+}
+
 static bool isTrigLibCall(CallInst *CI);
 static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
                              bool UseFloat, Value *&Sin, Value *&Cos,
@@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
       return optimizeExp2(CI, Builder);
     case Intrinsic::fabs:
       return optimizeFabs(CI, Builder);
+    case Intrinsic::sqrt:
+      return optimizeSqrt(CI, Builder);
     default:
       return nullptr;
     }
@@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
     case LibFunc::fabs:
     case LibFunc::fabsl:
       return optimizeFabs(CI, Builder);
+    case LibFunc::sqrtf:
+    case LibFunc::sqrt:
+    case LibFunc::sqrtl:
+      return optimizeSqrt(CI, Builder);
     case LibFunc::ffs:
     case LibFunc::ffsl:
     case LibFunc::ffsll:
@@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
     case LibFunc::logb:
     case LibFunc::sin:
     case LibFunc::sinh:
-    case LibFunc::sqrt:
     case LibFunc::tan:
     case LibFunc::tanh:
       if (UnsafeFPShrink && hasFloatVersion(FuncName))
```