diff options
Diffstat (limited to 'llvm/lib/Target/ARM/ARMISelLowering.cpp')
| -rw-r--r-- | llvm/lib/Target/ARM/ARMISelLowering.cpp | 133 |
1 files changed, 128 insertions, 5 deletions
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 62953f4be18..15ae0c7940b 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -1337,6 +1337,10 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const { case ARMISD::UMAAL: return "ARMISD::UMAAL"; case ARMISD::UMLAL: return "ARMISD::UMLAL"; case ARMISD::SMLAL: return "ARMISD::SMLAL"; + case ARMISD::SMLALBB: return "ARMISD::SMLALBB"; + case ARMISD::SMLALBT: return "ARMISD::SMLALBT"; + case ARMISD::SMLALTB: return "ARMISD::SMLALTB"; + case ARMISD::SMLALTT: return "ARMISD::SMLALTT"; case ARMISD::SMULWB: return "ARMISD::SMULWB"; case ARMISD::SMULWT: return "ARMISD::SMULWT"; case ARMISD::BUILD_VECTOR: return "ARMISD::BUILD_VECTOR"; @@ -9497,8 +9501,90 @@ static SDValue findMUL_LOHI(SDValue V) { return SDValue(); } +static SDValue AddCombineTo64BitSMLAL16(SDNode *AddcNode, SDNode *AddeNode, + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { + + if (Subtarget->isThumb()) { + if (!Subtarget->hasDSP()) + return SDValue(); + } else if (!Subtarget->hasV5TEOps()) + return SDValue(); + + // SMLALBB, SMLALBT, SMLALTB, SMLALTT multiply two 16-bit values and + // accumulates the product into a 64-bit value. The 16-bit values will + // be sign extended somehow or SRA'd into 32-bit values + // (addc (adde (mul 16bit, 16bit), lo), hi) + SDValue Mul = AddcNode->getOperand(0); + SDValue Hi = AddcNode->getOperand(1); + if (Mul.getOpcode() != ISD::MUL) { + Hi = AddcNode->getOperand(0); + Mul = AddcNode->getOperand(1); + if (Mul.getOpcode() != ISD::MUL) + return SDValue(); + } + + SDValue SRA = AddeNode->getOperand(0); + SDValue Lo = AddeNode->getOperand(1); + if (SRA.getOpcode() != ISD::SRA) { + SRA = AddeNode->getOperand(1); + Lo = AddeNode->getOperand(0); + if (SRA.getOpcode() != ISD::SRA) + return SDValue(); + } + if (auto Const = dyn_cast<ConstantSDNode>(SRA.getOperand(1))) { + if (Const->getZExtValue() != 31) + return SDValue(); + } else + return SDValue(); + + if (SRA.getOperand(0) != Mul) + return SDValue(); + + SelectionDAG &DAG = DCI.DAG; + SDLoc dl(AddcNode); + unsigned Opcode = 0; + SDValue Op0; + SDValue Op1; + + if (isS16(Mul.getOperand(0), DAG) && isS16(Mul.getOperand(1), DAG)) { + Opcode = ARMISD::SMLALBB; + Op0 = Mul.getOperand(0); + Op1 = Mul.getOperand(1); + } else if (isS16(Mul.getOperand(0), DAG) && isSRA16(Mul.getOperand(1))) { + Opcode = ARMISD::SMLALBT; + Op0 = Mul.getOperand(0); + Op1 = Mul.getOperand(1).getOperand(0); + } else if (isSRA16(Mul.getOperand(0)) && isS16(Mul.getOperand(1), DAG)) { + Opcode = ARMISD::SMLALTB; + Op0 = Mul.getOperand(0).getOperand(0); + Op1 = Mul.getOperand(1); + } else if (isSRA16(Mul.getOperand(0)) && isSRA16(Mul.getOperand(1))) { + Opcode = ARMISD::SMLALTT; + Op0 = Mul->getOperand(0).getOperand(0); + Op1 = Mul->getOperand(1).getOperand(0); + } + + if (!Op0 || !Op1) + return SDValue(); + + SDValue SMLAL = DAG.getNode(Opcode, dl, DAG.getVTList(MVT::i32, MVT::i32), + Op0, Op1, Lo, Hi); + // Replace the ADDs' nodes uses by the MLA node's values. + SDValue HiMLALResult(SMLAL.getNode(), 1); + SDValue LoMLALResult(SMLAL.getNode(), 0); + + DAG.ReplaceAllUsesOfValueWith(SDValue(AddcNode, 0), LoMLALResult); + DAG.ReplaceAllUsesOfValueWith(SDValue(AddeNode, 0), HiMLALResult); + + // Return original node to notify the driver to stop replacing. + SDValue resNode(AddcNode, 0); + return resNode; +} + static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const ARMSubtarget *Subtarget) { // Look for multiply add opportunities. // The pattern is a ISD::UMUL_LOHI followed by two add nodes, where // each add nodes consumes a value from ISD::UMUL_LOHI and there is @@ -9535,12 +9621,13 @@ static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode, AddcNode->getValueType(0) == MVT::i32 && "Expect ADDC with two result values. First: i32"); - // Check that the ADDC adds the low result of the S/UMUL_LOHI. + // Check that the ADDC adds the low result of the S/UMUL_LOHI. If not, it + // maybe a SMLAL which multiplies two 16-bit values. if (AddcOp0->getOpcode() != ISD::UMUL_LOHI && AddcOp0->getOpcode() != ISD::SMUL_LOHI && AddcOp1->getOpcode() != ISD::UMUL_LOHI && AddcOp1->getOpcode() != ISD::SMUL_LOHI) - return SDValue(); + return AddCombineTo64BitSMLAL16(AddcNode, AddeNode, DCI, Subtarget); // Check for the triangle shape. SDValue AddeOp0 = AddeNode->getOperand(0); @@ -9628,7 +9715,7 @@ static SDValue AddCombineTo64bitUMAAL(SDNode *AddeNode, // as the addend, and it's handled in PerformUMLALCombine. if (!Subtarget->hasV6Ops() || !Subtarget->hasDSP()) - return AddCombineTo64bitMLAL(AddeNode, DCI); + return AddCombineTo64bitMLAL(AddeNode, DCI, Subtarget); // Check that we have a glued ADDC node. SDNode* AddcNode = AddeNode->getOperand(2).getNode(); @@ -9645,7 +9732,7 @@ static SDValue AddCombineTo64bitUMAAL(SDNode *AddeNode, UmlalNode = AddcNode->getOperand(1).getNode(); AddHi = AddcNode->getOperand(0); } else { - return AddCombineTo64bitMLAL(AddeNode, DCI); + return AddCombineTo64bitMLAL(AddeNode, DCI, Subtarget); } // The ADDC should be glued to an ADDE node, which uses the same UMLAL as @@ -11894,6 +11981,42 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, return SDValue(); break; } + case ARMISD::SMLALBB: { + unsigned BitWidth = N->getValueType(0).getSizeInBits(); + APInt DemandedMask = APInt::getLowBitsSet(BitWidth, 16); + if ((SimplifyDemandedBits(N->getOperand(0), DemandedMask, DCI)) || + (SimplifyDemandedBits(N->getOperand(1), DemandedMask, DCI))) + return SDValue(); + break; + } + case ARMISD::SMLALBT: { + unsigned LowWidth = N->getOperand(0).getValueType().getSizeInBits(); + APInt LowMask = APInt::getLowBitsSet(LowWidth, 16); + unsigned HighWidth = N->getOperand(1).getValueType().getSizeInBits(); + APInt HighMask = APInt::getHighBitsSet(HighWidth, 16); + if ((SimplifyDemandedBits(N->getOperand(0), LowMask, DCI)) || + (SimplifyDemandedBits(N->getOperand(1), HighMask, DCI))) + return SDValue(); + break; + } + case ARMISD::SMLALTB: { + unsigned HighWidth = N->getOperand(0).getValueType().getSizeInBits(); + APInt HighMask = APInt::getHighBitsSet(HighWidth, 16); + unsigned LowWidth = N->getOperand(1).getValueType().getSizeInBits(); + APInt LowMask = APInt::getLowBitsSet(LowWidth, 16); + if ((SimplifyDemandedBits(N->getOperand(0), HighMask, DCI)) || + (SimplifyDemandedBits(N->getOperand(1), LowMask, DCI))) + return SDValue(); + break; + } + case ARMISD::SMLALTT: { + unsigned BitWidth = N->getValueType(0).getSizeInBits(); + APInt DemandedMask = APInt::getHighBitsSet(BitWidth, 16); + if ((SimplifyDemandedBits(N->getOperand(0), DemandedMask, DCI)) || + (SimplifyDemandedBits(N->getOperand(1), DemandedMask, DCI))) + return SDValue(); + break; + } case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) { |

