diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index b4a7352a046..6af2d9d9d06 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -478,6 +478,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::SINT_TO_FP); setTargetDAGCombine(ISD::UINT_TO_FP); + setTargetDAGCombine(ISD::FP_TO_SINT); + setTargetDAGCombine(ISD::FP_TO_UINT); + setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); setTargetDAGCombine(ISD::ANY_EXTEND); @@ -7529,6 +7532,70 @@ static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); } +/// Fold a floating-point multiply by power of two into floating-point to +/// fixed-point conversion. +static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG, + const AArch64Subtarget *Subtarget) { + if (!Subtarget->hasNEON()) + return SDValue(); + + SDValue Op = N->getOperand(0); + if (!Op.getValueType().isVector() || Op.getOpcode() != ISD::FMUL) + return SDValue(); + + SDValue ConstVec = Op->getOperand(1); + if (!isa<BuildVectorSDNode>(ConstVec)) + return SDValue(); + + MVT FloatTy = Op.getSimpleValueType().getVectorElementType(); + uint32_t FloatBits = FloatTy.getSizeInBits(); + if (FloatBits != 32 && FloatBits != 64) + return SDValue(); + + MVT IntTy = N->getSimpleValueType(0).getVectorElementType(); + uint32_t IntBits = IntTy.getSizeInBits(); + if (IntBits != 16 && IntBits != 32 && IntBits != 64) + return SDValue(); + + // Avoid conversions where iN is larger than the float (e.g., float -> i64). + if (IntBits > FloatBits) + return SDValue(); + + BitVector UndefElements; + BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec); + int32_t Bits = IntBits == 64 ? 64 : 32; + int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, Bits + 1); + if (C == -1 || C == 0 || C > Bits) + return SDValue(); + + MVT ResTy; + unsigned NumLanes = Op.getValueType().getVectorNumElements(); + switch (NumLanes) { + default: + return SDValue(); + case 2: + ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64; + break; + case 4: + ResTy = MVT::v4i32; + break; + } + + SDLoc DL(N); + bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT; + unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfp2fxs + : Intrinsic::aarch64_neon_vcvtfp2fxu; + SDValue FixConv = + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy, + DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), + Op->getOperand(0), DAG.getConstant(C, DL, MVT::i32)); + // We can handle smaller integers by generating an extra trunc. + if (IntBits < FloatBits) + FixConv = DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), FixConv); + + return FixConv; +} + /// An EXTR instruction is made up of two shifts, ORed together. This helper /// searches for and classifies those shifts. static bool findEXTRHalf(SDValue N, SDValue &Src, uint32_t &ShiftAmount, @@ -9400,6 +9467,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, case ISD::SINT_TO_FP: case ISD::UINT_TO_FP: return performIntToFpCombine(N, DAG, Subtarget); + case ISD::FP_TO_SINT: + case ISD::FP_TO_UINT: + return performFpToIntCombine(N, DAG, Subtarget); case ISD::OR: return performORCombine(N, DCI, Subtarget); case ISD::INTRINSIC_WO_CHAIN: |