diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 72 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.td | 19 |
2 files changed, 89 insertions, 2 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 27dd4249770..703cccb3dbf 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -703,9 +703,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, for (MVT VT : MVT::vector_valuetypes()) { setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand); - setOperationAction(ISD::MULHS, VT, Expand); + if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32) { + setOperationAction(ISD::MULHS, VT, Custom); + setOperationAction(ISD::MULHU, VT, Custom); + } else { + setOperationAction(ISD::MULHS, VT, Expand); + setOperationAction(ISD::MULHU, VT, Expand); + } setOperationAction(ISD::SMUL_LOHI, VT, Expand); - setOperationAction(ISD::MULHU, VT, Expand); setOperationAction(ISD::UMUL_LOHI, VT, Expand); setOperationAction(ISD::BSWAP, VT, Expand); @@ -2549,6 +2554,66 @@ static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) { DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)); } +// Lower vector multiply high (ISD::MULHS and ISD::MULHU). +static SDValue LowerMULH(SDValue Op, SelectionDAG &DAG) { + // Multiplications are only custom-lowered for 128-bit vectors so that + // {S,U}MULL{2} can be detected. Otherwise v2i64 multiplications are not + // legal. 
+ EVT VT = Op.getValueType(); + assert(VT.is128BitVector() && VT.isInteger() && + "unexpected type for custom-lowering ISD::MULH{U,S}"); + + SDValue V0 = Op.getOperand(0); + SDValue V1 = Op.getOperand(1); + + SDLoc DL(Op); + + EVT ExtractVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); + + // We turn (V0 mulhs/mulhu V1) into (with umull replacing smull for mulhu): + // + // (uzp2 (smull (extract_subvector (ExtractVT V128:V0, (i64 0)), + // (extract_subvector (ExtractVT V128:V1, (i64 0))))), + // (smull (extract_subvector (ExtractVT V128:V0, (i64 VMull2Idx)), + // (extract_subvector (ExtractVT V128:V1, (i64 VMull2Idx)))))) + // + // Where ExtractVT is a subvector with half the number of elements, and + // VMull2Idx is the index of the middle element (the high part). + // + // The high-part extract and multiply will be matched against + // {S,U}MULL{v16i8_v8i16,v8i16_v4i32,v4i32_v2i64} which in turn will + // issue a {s,u}mull2 instruction. + // + // This basically multiplies the lower subvector with '{s,u}mull', the high + // subvector with '{s,u}mull2', and shuffles the high parts of both results + // into the resulting vector. + unsigned Mull2VectorIdx = VT.getVectorNumElements () / 2; + SDValue VMullIdx = DAG.getConstant(0, DL, MVT::i64); + SDValue VMull2Idx = DAG.getConstant(Mull2VectorIdx, DL, MVT::i64); + + SDValue VMullV0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, V0, VMullIdx); + SDValue VMullV1 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, V1, VMullIdx); + + SDValue VMull2V0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, V0, VMull2Idx); + SDValue VMull2V1 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, V1, VMull2Idx); + + unsigned MullOpc = Op.getOpcode() == ISD::MULHS ? 
+ AArch64ISD::SMULL + : AArch64ISD::UMULL; + + EVT MullVT = ExtractVT.widenIntegerVectorElementType(*DAG.getContext()); + SDValue Mull = DAG.getNode(MullOpc, DL, MullVT, VMullV0, VMullV1); + SDValue Mull2 = DAG.getNode(MullOpc, DL, MullVT, VMull2V0, VMull2V1); + + Mull = DAG.getNode(ISD::BITCAST, DL, VT, Mull); + Mull2 = DAG.getNode(ISD::BITCAST, DL, VT, Mull2); + + return DAG.getNode(AArch64ISD::UZP2, DL, VT, Mull, Mull2); +} + SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const { unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue(); @@ -2681,6 +2746,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerFSINCOS(Op, DAG); case ISD::MUL: return LowerMUL(Op, DAG); + case ISD::MULHS: + case ISD::MULHU: + return LowerMULH(Op, DAG); case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG); case ISD::VECREDUCE_ADD: diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index d7b9a804247..273fd0a7fd9 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -3773,6 +3773,25 @@ defm : Neon_mul_widen_patterns<AArch64smull, SMULLv8i8_v8i16, defm : Neon_mul_widen_patterns<AArch64umull, UMULLv8i8_v8i16, UMULLv4i16_v4i32, UMULLv2i32_v2i64>; +// Patterns for smull2/umull2 (widening multiply of the high vector halves). 
+multiclass Neon_mul_high_patterns<SDPatternOperator opnode, + Instruction INST8B, Instruction INST4H, Instruction INST2S> { + def : Pat<(v8i16 (opnode (extract_high_v16i8 V128:$Rn), + (extract_high_v16i8 V128:$Rm))), + (INST8B V128:$Rn, V128:$Rm)>; + def : Pat<(v4i32 (opnode (extract_high_v8i16 V128:$Rn), + (extract_high_v8i16 V128:$Rm))), + (INST4H V128:$Rn, V128:$Rm)>; + def : Pat<(v2i64 (opnode (extract_high_v4i32 V128:$Rn), + (extract_high_v4i32 V128:$Rm))), + (INST2S V128:$Rn, V128:$Rm)>; +} + +defm : Neon_mul_high_patterns<AArch64smull, SMULLv16i8_v8i16, + SMULLv8i16_v4i32, SMULLv4i32_v2i64>; +defm : Neon_mul_high_patterns<AArch64umull, UMULLv16i8_v8i16, + UMULLv8i16_v4i32, UMULLv4i32_v2i64>; + // Additional patterns for SMLAL/SMLSL and UMLAL/UMLSL multiclass Neon_mulacc_widen_patterns<SDPatternOperator opnode, Instruction INST8B, Instruction INST4H, Instruction INST2S> { |