diff options
Diffstat (limited to 'llvm/lib/Target/ARM/ARMISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/ARM/ARMISelLowering.cpp | 53 |
1 files changed, 48 insertions, 5 deletions
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 222b5bca7a6..323e900a5f7 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -258,6 +258,7 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setOperationAction(ISD::UMIN, VT, Legal); setOperationAction(ISD::UMAX, VT, Legal); setOperationAction(ISD::ABS, VT, Legal); + setOperationAction(ISD::SETCC, VT, Custom); // No native support for these. setOperationAction(ISD::UDIV, VT, Expand); @@ -334,6 +335,12 @@ void ARMTargetLowering::addMVEVectorTypes(bool HasMVEFP) { setTruncStoreAction(MVT::v4i32, MVT::v4i16, Legal); setTruncStoreAction(MVT::v4i32, MVT::v4i8, Legal); setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal); + + // Predicate types + const MVT pTypes[] = {MVT::v16i1, MVT::v8i1, MVT::v4i1}; + for (auto VT : pTypes) { + addRegisterClass(VT, &ARM::VCCRRegClass); + } } ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM, @@ -1500,6 +1507,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const { case ARMISD::VCEQ: return "ARMISD::VCEQ"; case ARMISD::VCEQZ: return "ARMISD::VCEQZ"; + case ARMISD::VCNE: return "ARMISD::VCNE"; + case ARMISD::VCNEZ: return "ARMISD::VCNEZ"; case ARMISD::VCGE: return "ARMISD::VCGE"; case ARMISD::VCGEZ: return "ARMISD::VCGEZ"; case ARMISD::VCLEZ: return "ARMISD::VCLEZ"; @@ -1601,6 +1610,11 @@ EVT ARMTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &, EVT VT) const { if (!VT.isVector()) return getPointerTy(DL); + + // MVE has a predicate register. + if (Subtarget->hasMVEIntegerOps() && + (VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8)) + return MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); return VT.changeVectorElementTypeToInteger(); } @@ -5849,7 +5863,8 @@ static SDValue Expand64BitShift(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Lo, Hi); } -static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) { +static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG, + const ARMSubtarget *ST) { SDValue TmpOp0, TmpOp1; bool Invert = false; bool Swap = false; @@ -5858,11 +5873,23 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); SDValue CC = Op.getOperand(2); - EVT CmpVT = Op0.getValueType().changeVectorElementTypeToInteger(); EVT VT = Op.getValueType(); ISD::CondCode SetCCOpcode = cast<CondCodeSDNode>(CC)->get(); SDLoc dl(Op); + EVT CmpVT; + if (ST->hasNEON()) + CmpVT = Op0.getValueType().changeVectorElementTypeToInteger(); + else { + assert(ST->hasMVEIntegerOps() && + "No hardware support for integer vector comparison!"); + + if (Op.getValueType().getVectorElementType() != MVT::i1) + return SDValue(); + + CmpVT = VT; + } + if (Op0.getValueType().getVectorElementType() == MVT::i64 && (SetCCOpcode == ISD::SETEQ || SetCCOpcode == ISD::SETNE)) { // Special-case integer 64-bit equality comparisons. They aren't legal, @@ -5930,7 +5957,12 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) { // Integer comparisons. switch (SetCCOpcode) { default: llvm_unreachable("Illegal integer comparison"); - case ISD::SETNE: Invert = true; LLVM_FALLTHROUGH; + case ISD::SETNE: + if (ST->hasMVEIntegerOps()) { + Opc = ARMISD::VCNE; break; + } else { + Invert = true; LLVM_FALLTHROUGH; + } case ISD::SETEQ: Opc = ARMISD::VCEQ; break; case ISD::SETLT: Swap = true; LLVM_FALLTHROUGH; case ISD::SETGT: Opc = ARMISD::VCGT; break; @@ -5943,7 +5975,7 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) { } // Detect VTST (Vector Test Bits) = icmp ne (and (op0, op1), zero). - if (Opc == ARMISD::VCEQ) { + if (ST->hasNEON() && Opc == ARMISD::VCEQ) { SDValue AndOp; if (ISD::isBuildVectorAllZeros(Op1.getNode())) AndOp = Op0; @@ -5982,6 +6014,9 @@ static SDValue LowerVSETCC(SDValue Op, SelectionDAG &DAG) { SDValue Result; if (SingleOp.getNode()) { switch (Opc) { + case ARMISD::VCNE: + assert(ST->hasMVEIntegerOps() && "Unexpected DAG node"); + Result = DAG.getNode(ARMISD::VCNEZ, dl, CmpVT, SingleOp); break; case ARMISD::VCEQ: Result = DAG.getNode(ARMISD::VCEQZ, dl, CmpVT, SingleOp); break; case ARMISD::VCGE: @@ -8436,7 +8471,7 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::CTTZ: case ISD::CTTZ_ZERO_UNDEF: return LowerCTTZ(Op.getNode(), DAG, Subtarget); case ISD::CTPOP: return LowerCTPOP(Op.getNode(), DAG, Subtarget); - case ISD::SETCC: return LowerVSETCC(Op, DAG); + case ISD::SETCC: return LowerVSETCC(Op, DAG, Subtarget); case ISD::SETCCCARRY: return LowerSETCCCARRY(Op, DAG); case ISD::ConstantFP: return LowerConstantFP(Op, DAG, Subtarget); case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG, Subtarget); @@ -13594,6 +13629,14 @@ bool ARMTargetLowering::allowsMisalignedMemoryAccesses(EVT VT, unsigned, if (!Subtarget->hasMVEIntegerOps()) return false; + + // These are for predicates + if ((Ty == MVT::v16i1 || Ty == MVT::v8i1 || Ty == MVT::v4i1)) { + if (Fast) + *Fast = true; + return true; + } + if (Ty != MVT::v16i8 && Ty != MVT::v8i16 && Ty != MVT::v8f16 && Ty != MVT::v4i32 && Ty != MVT::v4f32 && Ty != MVT::v2i64 && Ty != MVT::v2f64 && |