summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Target/ARM/ARMISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/ARM/ARMISelLowering.cpp')
-rw-r--r--llvm/lib/Target/ARM/ARMISelLowering.cpp133
1 files changed, 128 insertions, 5 deletions
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 62953f4be18..15ae0c7940b 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1337,6 +1337,10 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
case ARMISD::UMAAL: return "ARMISD::UMAAL";
case ARMISD::UMLAL: return "ARMISD::UMLAL";
case ARMISD::SMLAL: return "ARMISD::SMLAL";
+ case ARMISD::SMLALBB: return "ARMISD::SMLALBB";
+ case ARMISD::SMLALBT: return "ARMISD::SMLALBT";
+ case ARMISD::SMLALTB: return "ARMISD::SMLALTB";
+ case ARMISD::SMLALTT: return "ARMISD::SMLALTT";
case ARMISD::SMULWB: return "ARMISD::SMULWB";
case ARMISD::SMULWT: return "ARMISD::SMULWT";
case ARMISD::BUILD_VECTOR: return "ARMISD::BUILD_VECTOR";
@@ -9497,8 +9501,90 @@ static SDValue findMUL_LOHI(SDValue V) {
return SDValue();
}
+static SDValue AddCombineTo64BitSMLAL16(SDNode *AddcNode, SDNode *AddeNode,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const ARMSubtarget *Subtarget) {
+
+ if (Subtarget->isThumb()) {
+ if (!Subtarget->hasDSP())
+ return SDValue();
+ } else if (!Subtarget->hasV5TEOps())
+ return SDValue();
+
+ // SMLALBB, SMLALBT, SMLALTB, SMLALTT multiply two 16-bit values and
+ // accumulates the product into a 64-bit value. The 16-bit values will
+ // be sign extended somehow or SRA'd into 32-bit values
+ // (addc (adde (mul 16bit, 16bit), lo), hi)
+ SDValue Mul = AddcNode->getOperand(0);
+ SDValue Hi = AddcNode->getOperand(1);
+ if (Mul.getOpcode() != ISD::MUL) {
+ Hi = AddcNode->getOperand(0);
+ Mul = AddcNode->getOperand(1);
+ if (Mul.getOpcode() != ISD::MUL)
+ return SDValue();
+ }
+
+ SDValue SRA = AddeNode->getOperand(0);
+ SDValue Lo = AddeNode->getOperand(1);
+ if (SRA.getOpcode() != ISD::SRA) {
+ SRA = AddeNode->getOperand(1);
+ Lo = AddeNode->getOperand(0);
+ if (SRA.getOpcode() != ISD::SRA)
+ return SDValue();
+ }
+ if (auto Const = dyn_cast<ConstantSDNode>(SRA.getOperand(1))) {
+ if (Const->getZExtValue() != 31)
+ return SDValue();
+ } else
+ return SDValue();
+
+ if (SRA.getOperand(0) != Mul)
+ return SDValue();
+
+ SelectionDAG &DAG = DCI.DAG;
+ SDLoc dl(AddcNode);
+ unsigned Opcode = 0;
+ SDValue Op0;
+ SDValue Op1;
+
+ if (isS16(Mul.getOperand(0), DAG) && isS16(Mul.getOperand(1), DAG)) {
+ Opcode = ARMISD::SMLALBB;
+ Op0 = Mul.getOperand(0);
+ Op1 = Mul.getOperand(1);
+ } else if (isS16(Mul.getOperand(0), DAG) && isSRA16(Mul.getOperand(1))) {
+ Opcode = ARMISD::SMLALBT;
+ Op0 = Mul.getOperand(0);
+ Op1 = Mul.getOperand(1).getOperand(0);
+ } else if (isSRA16(Mul.getOperand(0)) && isS16(Mul.getOperand(1), DAG)) {
+ Opcode = ARMISD::SMLALTB;
+ Op0 = Mul.getOperand(0).getOperand(0);
+ Op1 = Mul.getOperand(1);
+ } else if (isSRA16(Mul.getOperand(0)) && isSRA16(Mul.getOperand(1))) {
+ Opcode = ARMISD::SMLALTT;
+ Op0 = Mul->getOperand(0).getOperand(0);
+ Op1 = Mul->getOperand(1).getOperand(0);
+ }
+
+ if (!Op0 || !Op1)
+ return SDValue();
+
+ SDValue SMLAL = DAG.getNode(Opcode, dl, DAG.getVTList(MVT::i32, MVT::i32),
+ Op0, Op1, Lo, Hi);
+ // Replace the ADDs' nodes uses by the MLA node's values.
+ SDValue HiMLALResult(SMLAL.getNode(), 1);
+ SDValue LoMLALResult(SMLAL.getNode(), 0);
+
+ DAG.ReplaceAllUsesOfValueWith(SDValue(AddcNode, 0), LoMLALResult);
+ DAG.ReplaceAllUsesOfValueWith(SDValue(AddeNode, 0), HiMLALResult);
+
+ // Return original node to notify the driver to stop replacing.
+ SDValue resNode(AddcNode, 0);
+ return resNode;
+}
+
static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode,
- TargetLowering::DAGCombinerInfo &DCI) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ const ARMSubtarget *Subtarget) {
// Look for multiply add opportunities.
// The pattern is a ISD::UMUL_LOHI followed by two add nodes, where
// each add nodes consumes a value from ISD::UMUL_LOHI and there is
@@ -9535,12 +9621,13 @@ static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode,
AddcNode->getValueType(0) == MVT::i32 &&
"Expect ADDC with two result values. First: i32");
- // Check that the ADDC adds the low result of the S/UMUL_LOHI.
+ // Check that the ADDC adds the low result of the S/UMUL_LOHI. If not, it
+ // maybe a SMLAL which multiplies two 16-bit values.
if (AddcOp0->getOpcode() != ISD::UMUL_LOHI &&
AddcOp0->getOpcode() != ISD::SMUL_LOHI &&
AddcOp1->getOpcode() != ISD::UMUL_LOHI &&
AddcOp1->getOpcode() != ISD::SMUL_LOHI)
- return SDValue();
+ return AddCombineTo64BitSMLAL16(AddcNode, AddeNode, DCI, Subtarget);
// Check for the triangle shape.
SDValue AddeOp0 = AddeNode->getOperand(0);
@@ -9628,7 +9715,7 @@ static SDValue AddCombineTo64bitUMAAL(SDNode *AddeNode,
// as the addend, and it's handled in PerformUMLALCombine.
if (!Subtarget->hasV6Ops() || !Subtarget->hasDSP())
- return AddCombineTo64bitMLAL(AddeNode, DCI);
+ return AddCombineTo64bitMLAL(AddeNode, DCI, Subtarget);
// Check that we have a glued ADDC node.
SDNode* AddcNode = AddeNode->getOperand(2).getNode();
@@ -9645,7 +9732,7 @@ static SDValue AddCombineTo64bitUMAAL(SDNode *AddeNode,
UmlalNode = AddcNode->getOperand(1).getNode();
AddHi = AddcNode->getOperand(0);
} else {
- return AddCombineTo64bitMLAL(AddeNode, DCI);
+ return AddCombineTo64bitMLAL(AddeNode, DCI, Subtarget);
}
// The ADDC should be glued to an ADDE node, which uses the same UMLAL as
@@ -11894,6 +11981,42 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
return SDValue();
break;
}
+ case ARMISD::SMLALBB: {
+ unsigned BitWidth = N->getValueType(0).getSizeInBits();
+ APInt DemandedMask = APInt::getLowBitsSet(BitWidth, 16);
+ if ((SimplifyDemandedBits(N->getOperand(0), DemandedMask, DCI)) ||
+ (SimplifyDemandedBits(N->getOperand(1), DemandedMask, DCI)))
+ return SDValue();
+ break;
+ }
+ case ARMISD::SMLALBT: {
+ unsigned LowWidth = N->getOperand(0).getValueType().getSizeInBits();
+ APInt LowMask = APInt::getLowBitsSet(LowWidth, 16);
+ unsigned HighWidth = N->getOperand(1).getValueType().getSizeInBits();
+ APInt HighMask = APInt::getHighBitsSet(HighWidth, 16);
+ if ((SimplifyDemandedBits(N->getOperand(0), LowMask, DCI)) ||
+ (SimplifyDemandedBits(N->getOperand(1), HighMask, DCI)))
+ return SDValue();
+ break;
+ }
+ case ARMISD::SMLALTB: {
+ unsigned HighWidth = N->getOperand(0).getValueType().getSizeInBits();
+ APInt HighMask = APInt::getHighBitsSet(HighWidth, 16);
+ unsigned LowWidth = N->getOperand(1).getValueType().getSizeInBits();
+ APInt LowMask = APInt::getLowBitsSet(LowWidth, 16);
+ if ((SimplifyDemandedBits(N->getOperand(0), HighMask, DCI)) ||
+ (SimplifyDemandedBits(N->getOperand(1), LowMask, DCI)))
+ return SDValue();
+ break;
+ }
+ case ARMISD::SMLALTT: {
+ unsigned BitWidth = N->getValueType(0).getSizeInBits();
+ APInt DemandedMask = APInt::getHighBitsSet(BitWidth, 16);
+ if ((SimplifyDemandedBits(N->getOperand(0), DemandedMask, DCI)) ||
+ (SimplifyDemandedBits(N->getOperand(1), DemandedMask, DCI)))
+ return SDValue();
+ break;
+ }
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
OpenPOWER on IntegriCloud