diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 95 |
1 files changed, 37 insertions, 58 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 0156fe1c0ec..2ca3f3e452c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -644,8 +644,13 @@ static ConstantSDNode *isConstOrConstSplat(SDValue N) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) return CN; - if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) - return BV->getConstantSplatValue(); + if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) { + ConstantSDNode *CN = BV->getConstantSplatValue(); + + // BuildVectors can truncate their operands. Ignore that case here. + if (CN && CN->getValueType(0) == N.getValueType().getScalarType()) + return CN; + } return nullptr; } @@ -1957,8 +1962,8 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode()); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold vector ops @@ -1985,25 +1990,15 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { N0, N1); } - const APInt *Divisor = nullptr; - if (N1C) { - Divisor = &N1C->getAPIntValue(); - } else if (N1.getValueType().isVector() && - N1->getOpcode() == ISD::BUILD_VECTOR) { - BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1)); - if (ConstantSDNode *C = BV->getConstantSplatValue()) - Divisor = &C->getAPIntValue(); - } - // fold (sdiv X, pow2) -> simple ops after legalize - if (Divisor && !!*Divisor && - (Divisor->isPowerOf2() || (-*Divisor).isPowerOf2())) { + if (N1C && !N1C->isNullValue() && (N1C->getAPIntValue().isPowerOf2() || + (-N1C->getAPIntValue()).isPowerOf2())) { // If dividing by powers of two is cheap, then don't perform the following // fold. if (TLI.isPow2DivCheap()) return SDValue(); - unsigned lg2 = Divisor->countTrailingZeros(); + unsigned lg2 = N1C->getAPIntValue().countTrailingZeros(); // Splat the sign bit into the register SDValue SGN = @@ -2025,7 +2020,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { // If we're dividing by a positive value, we're done. Otherwise, we must // negate the result. - if (Divisor->isNonNegative()) + if (N1C->getAPIntValue().isNonNegative()) return SRA; AddToWorkList(SRA.getNode()); @@ -2034,7 +2029,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { // if integer divide is expensive and we satisfy the requirements, emit an // alternate sequence. - if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) { + if (N1C && !TLI.isIntDivCheap()) { SDValue Op = BuildSDIV(N); if (Op.getNode()) return Op; } @@ -2052,8 +2047,8 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0.getNode()); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1.getNode()); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold vector ops @@ -2086,7 +2081,7 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { } } // fold (udiv x, c) -> alternate - if ((N1C || N1->getOpcode() == ISD::BUILD_VECTOR) && !TLI.isIntDivCheap()) { + if (N1C && !TLI.isIntDivCheap()) { SDValue Op = BuildUDIV(N); if (Op.getNode()) return Op; } @@ -2104,8 +2099,8 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { SDValue DAGCombiner::visitSREM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold (srem c1, c2) -> c1%c2 @@ -2146,8 +2141,8 @@ SDValue DAGCombiner::visitSREM(SDNode *N) { SDValue DAGCombiner::visitUREM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0); - ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1); + ConstantSDNode *N0C = isConstOrConstSplat(N0); + ConstantSDNode *N1C = isConstOrConstSplat(N1); EVT VT = N->getValueType(0); // fold (urem c1, c2) -> c1%c2 @@ -11187,28 +11182,20 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, /// multiplying by a magic number. See: /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html> SDValue DAGCombiner::BuildSDIV(SDNode *N) { - const APInt *Divisor; - if (N->getValueType(0).isVector()) { - // Handle splat vectors. - BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1)); - if (ConstantSDNode *C = BV->getConstantSplatValue()) - Divisor = &C->getAPIntValue(); - else - return SDValue(); - } else { - Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue(); - } + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) + return SDValue(); // Avoid division by zero. - if (!*Divisor) + if (!C->getAPIntValue()) return SDValue(); std::vector<SDNode*> Built; - SDValue S = TLI.BuildSDIV(N, *Divisor, DAG, LegalOperations, &Built); + SDValue S = + TLI.BuildSDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); - for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end(); - ii != ee; ++ii) - AddToWorkList(*ii); + for (SDNode *N : Built) + AddToWorkList(N); return S; } @@ -11217,28 +11204,20 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) { /// multiplying by a magic number. See: /// <http://the.wall.riscom.net/books/proc/ppc/cwg/code2.html> SDValue DAGCombiner::BuildUDIV(SDNode *N) { - const APInt *Divisor; - if (N->getValueType(0).isVector()) { - // Handle splat vectors. - BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N->getOperand(1)); - if (ConstantSDNode *C = BV->getConstantSplatValue()) - Divisor = &C->getAPIntValue(); - else - return SDValue(); - } else { - Divisor = &cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue(); - } + ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1)); + if (!C) + return SDValue(); // Avoid division by zero. - if (!*Divisor) + if (!C->getAPIntValue()) return SDValue(); std::vector<SDNode*> Built; - SDValue S = TLI.BuildUDIV(N, *Divisor, DAG, LegalOperations, &Built); + SDValue S = + TLI.BuildUDIV(N, C->getAPIntValue(), DAG, LegalOperations, &Built); - for (std::vector<SDNode*>::iterator ii = Built.begin(), ee = Built.end(); - ii != ee; ++ii) - AddToWorkList(*ii); + for (SDNode *N : Built) + AddToWorkList(N); return S; } |