summaryrefslogtreecommitdiffstats
path: root/llvm
diff options
context:
space:
mode:
authorSanjay Patel <spatel@rotateright.com>2019-12-11 11:18:09 -0500
committerSanjay Patel <spatel@rotateright.com>2019-12-11 13:30:39 -0500
commitd1f0bdf2d2df9bdf11ee2ddfff3df50e53f2f042 (patch)
treed8ff7e8e6e718370607f772a1a62f737375cf00d /llvm
parentd8c31d41989b0748e2e5b8d7fa9cf7e7023bcbff (diff)
downloadbcm5719-llvm-d1f0bdf2d2df9bdf11ee2ddfff3df50e53f2f042.tar.gz
bcm5719-llvm-d1f0bdf2d2df9bdf11ee2ddfff3df50e53f2f042.zip
[SDAG] remove use restriction in isNegatibleForFree() when called from getNegatedExpression()
This is an alternate fix for the bug discussed in D70595. This also includes minimal tests for other in-tree targets to show the problem more generally. We check the number of uses as a predicate for whether some value is free to negate, but that use count can change as we rewrite the expression in getNegatedExpression(). So something that was marked free to negate during the cost evaluation phase becomes not free to negate during the rewrite phase (or the inverse - something that was not free becomes free). This can lead to a crash/assert because we expect that everything in an expression that is negatible to be handled in the corresponding code within getNegatedExpression(). This patch skips the use check during the rewrite phase. So we determine that some expression isNegatibleForFree (identically to without this patch), but during the rewrite, don't rely on use counts to decide how to create the optimal expression. Differential Revision: https://reviews.llvm.org/D70975
Diffstat (limited to 'llvm')
-rw-r--r--llvm/include/llvm/CodeGen/TargetLowering.h8
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp35
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.cpp8
-rw-r--r--llvm/lib/Target/X86/X86ISelLowering.h3
-rw-r--r--llvm/test/CodeGen/AArch64/arm64-fmadd.ll18
-rw-r--r--llvm/test/CodeGen/X86/fma-fneg-combine-2.ll20
6 files changed, 72 insertions, 20 deletions
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 0726bdfec20..687a2eb9296 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3442,8 +3442,16 @@ public:
/// Return 1 if we can compute the negated form of the specified expression
/// for the same cost as the expression itself, or 2 if we can compute the
/// negated form more cheaply than the expression itself. Else return 0.
+ ///
+ /// EnableUseCheck specifies whether the number of uses of a value affects
+ /// if negation is considered free. This is needed because the number of uses
+ /// of any value may change as we rewrite the expression. Therefore, when
+ /// called from getNegatedExpression(), we must explicitly set EnableUseCheck
+ /// to false to avoid getting a different answer than when called from other
+ /// contexts.
virtual char isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
bool LegalOperations, bool ForCodeSize,
+ bool EnableUseCheck = true,
unsigned Depth = 0) const;
/// If isNegatibleForFree returns true, return the newly negated expression.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f8afdaf086a..05011aebb9d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5413,18 +5413,21 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
bool LegalOperations, bool ForCodeSize,
+ bool EnableUseCheck,
unsigned Depth) const {
// fneg is removable even if it has multiple uses.
if (Op.getOpcode() == ISD::FNEG)
return 2;
- // Don't allow anything with multiple uses unless we know it is free.
+ // If the caller requires checking uses, don't allow anything with multiple
+ // uses unless we know it is free.
EVT VT = Op.getValueType();
const SDNodeFlags Flags = Op->getFlags();
const TargetOptions &Options = DAG.getTarget().Options;
- if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND &&
- isFPExtFree(VT, Op.getOperand(0).getValueType())))
- return 0;
+ if (EnableUseCheck)
+ if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND &&
+ isFPExtFree(VT, Op.getOperand(0).getValueType())))
+ return 0;
// Don't recurse exponentially.
if (Depth > SelectionDAG::MaxRecursionDepth)
@@ -5468,11 +5471,11 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
- ForCodeSize, Depth + 1))
+ ForCodeSize, EnableUseCheck, Depth + 1))
return V;
// fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
case ISD::FSUB:
// We can't turn -(A-B) into B-A when we honor signed zeros.
if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
@@ -5485,7 +5488,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
case ISD::FDIV:
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
- ForCodeSize, Depth + 1))
+ ForCodeSize, EnableUseCheck, Depth + 1))
return V;
// Ignore X * 2.0 because that is expected to be canonicalized to X + X.
@@ -5494,7 +5497,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
return 0;
return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
case ISD::FMA:
case ISD::FMAD: {
@@ -5504,15 +5507,15 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
// fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
// fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
char V2 = isNegatibleForFree(Op.getOperand(2), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
if (!V2)
return 0;
// One of Op0/Op1 must be cheaply negatible, then select the cheapest.
char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
char V01 = std::max(V0, V1);
return V01 ? std::max(V01, V2) : 0;
}
@@ -5521,7 +5524,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
case ISD::FP_ROUND:
case ISD::FSIN:
return isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
}
return 0;
@@ -5565,7 +5568,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
- Depth + 1))
+ false, Depth + 1))
return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
getNegatedExpression(Op.getOperand(0), DAG,
LegalOperations, ForCodeSize,
@@ -5592,7 +5595,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
case ISD::FDIV:
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
- Depth + 1))
+ false, Depth + 1))
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
getNegatedExpression(Op.getOperand(0), DAG,
LegalOperations, ForCodeSize,
@@ -5616,9 +5619,9 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
ForCodeSize, Depth + 1);
char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, false, Depth + 1);
char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, false, Depth + 1);
if (V0 >= V1) {
// fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
SDValue Neg0 = getNegatedExpression(
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 866ee5b9a60..cdb588ddb8a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41898,6 +41898,7 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
bool LegalOperations,
bool ForCodeSize,
+ bool EnableUseCheck,
unsigned Depth) const {
// fneg patterns are removable even if they have multiple uses.
if (isFNEG(DAG, Op.getNode(), Depth))
@@ -41926,7 +41927,7 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
// extra operand negations as well.
for (int i = 0; i != 3; ++i) {
char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, EnableUseCheck, Depth + 1);
if (V == 2)
return V;
}
@@ -41935,7 +41936,8 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
}
return TargetLowering::isNegatibleForFree(Op, DAG, LegalOperations,
- ForCodeSize, Depth);
+ ForCodeSize, EnableUseCheck,
+ Depth);
}
SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
@@ -41967,7 +41969,7 @@ SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
SmallVector<SDValue, 4> NewOps(Op.getNumOperands(), SDValue());
for (int i = 0; i != 3; ++i) {
char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations,
- ForCodeSize, Depth + 1);
+ ForCodeSize, false, Depth + 1);
if (V == 2)
NewOps[i] = getNegatedExpression(Op.getOperand(i), DAG, LegalOperations,
ForCodeSize, Depth + 1);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 01612006413..3bbf3b59ac5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -809,7 +809,8 @@ namespace llvm {
/// for the same cost as the expression itself, or 2 if we can compute the
/// negated form more cheaply than the expression itself. Else return 0.
char isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations,
- bool ForCodeSize, unsigned Depth) const override;
+ bool ForCodeSize, bool EnableUseCheck,
+ unsigned Depth) const override;
/// If isNegatibleForFree returns true, return the newly negated expression.
SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/AArch64/arm64-fmadd.ll b/llvm/test/CodeGen/AArch64/arm64-fmadd.ll
index 203ce623647..dffa83aa11b 100644
--- a/llvm/test/CodeGen/AArch64/arm64-fmadd.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-fmadd.ll
@@ -88,5 +88,23 @@ entry:
ret double %0
}
+; This would crash while trying getNegatedExpression().
+
+define float @negated_constant(float %x) {
+; CHECK-LABEL: negated_constant:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov w8, #-1037565952
+; CHECK-NEXT: mov w9, #1109917696
+; CHECK-NEXT: fmov s1, w8
+; CHECK-NEXT: fmul s1, s0, s1
+; CHECK-NEXT: fmov s2, w9
+; CHECK-NEXT: fmadd s0, s0, s2, s1
+; CHECK-NEXT: ret
+ %m = fmul float %x, 42.0
+ %fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m)
+ %nfma = fneg float %fma
+ ret float %nfma
+}
+
declare float @llvm.fma.f32(float, float, float) nounwind readnone
declare double @llvm.fma.f64(double, double, double) nounwind readnone
diff --git a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
index f9e87955270..9c846e3f555 100644
--- a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
+++ b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
@@ -86,4 +86,24 @@ entry:
ret float %1
}
+; This would crash while trying getNegatedExpression().
+
+define float @negated_constant(float %x) {
+; FMA3-LABEL: negated_constant:
+; FMA3: # %bb.0:
+; FMA3-NEXT: vmulss {{.*}}(%rip), %xmm0, %xmm1
+; FMA3-NEXT: vfmadd132ss {{.*#+}} xmm0 = (xmm0 * mem) + xmm1
+; FMA3-NEXT: retq
+;
+; FMA4-LABEL: negated_constant:
+; FMA4: # %bb.0:
+; FMA4-NEXT: vmulss {{.*}}(%rip), %xmm0, %xmm1
+; FMA4-NEXT: vfmaddss %xmm1, {{.*}}(%rip), %xmm0, %xmm0
+; FMA4-NEXT: retq
+ %m = fmul float %x, 42.0
+ %fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m)
+ %nfma = fneg float %fma
+ ret float %nfma
+}
+
declare float @llvm.fma.f32(float, float, float)
OpenPOWER on IntegriCloud